diff --git a/docs/dialect/aggregation-operators.rst b/docs/dialect/aggregation-operators.rst index 9887b87..a83e1c6 100644 --- a/docs/dialect/aggregation-operators.rst +++ b/docs/dialect/aggregation-operators.rst @@ -328,4 +328,128 @@ Related Operators ~~~~~~~~~~~~~~~~~ - :ref:`CLUSTER ` - Assign cluster IDs without merging +- :ref:`RASTERIZE ` - Rasterize intervals onto a fixed bin grid - :ref:`INTERSECTS ` - Test for overlap between specific pairs + +---- + +.. _rasterize-operator: + +RASTERIZE +--------- + +Rasterize interval data onto a fixed-resolution bin grid, counting overlaps per bin. + +Description +~~~~~~~~~~~ + +The ``RASTERIZE`` operator tiles the genome into fixed-width bins and counts the number of intervals overlapping each bin. It generates a bin grid using ``generate_series`` and joins it against the source table to count overlapping features per bin. + +This is useful for: + +- Summarising feature density at a user-defined resolution +- Creating fixed-resolution count tracks from interval data +- Quick visualisation of interval pile-ups across the genome + +An interval that spans multiple bins is counted in each of the bins it overlaps, matching the ``bedtools coverage`` convention. As a result, the sum of bin counts is generally greater than the number of source intervals — bin counts answer "how many intervals touch this bin?", not "how are intervals partitioned across bins?". + +The operator works as an aggregate function, returning one row per bin with the bin coordinates and the count. + +.. note:: + + RASTERIZE depends on ``LATERAL`` plus ``generate_series`` for bin generation, which DuckDB and PostgreSQL both support. SQLite does not currently provide either primitive, so this operator is not yet available on the SQLite backend. + +.. note:: + + Only the ``count`` aggregation is supported in this release. Weighted summary statistics (mean, sum, min, max) over interval values raise non-trivial semantic questions when intervals span bin boundaries (full-value contribution vs. length-weighted vs. per-base depth) and are tracked as a follow-up. + +Syntax +~~~~~~ + +.. code-block:: sql + + -- Count overlapping intervals per bin + SELECT RASTERIZE(interval, ) FROM features + + -- Named resolution parameter + SELECT RASTERIZE(interval, resolution := 500) FROM features + +Parameters +~~~~~~~~~~ + +**interval** + A genomic column. + +**resolution** *(required)* + Bin width in base pairs — must be a positive integer literal. Can be given as a positional or named parameter (``RASTERIZE(interval, 1000)`` or ``RASTERIZE(interval, resolution := 1000)``). Omitting it, or supplying a non-positive value, raises ``ValueError`` at transpile time. + +Return Value +~~~~~~~~~~~~ + +Returns one row per genomic bin: + +- ``chrom`` — Chromosome of the bin +- ``start`` — Start position of the bin +- ``end`` — End position of the bin +- ``value`` — The count of intervals overlapping the bin (default alias; use ``AS`` to rename) + +Examples +~~~~~~~~ + +**Basic Count:** + +Count the number of features overlapping each 1 kb bin: + +.. code-block:: sql + + SELECT RASTERIZE(interval, 1000) + FROM features + +**Named Alias:** + +.. code-block:: sql + + SELECT RASTERIZE(interval, 1000) AS depth + FROM reads + +**With WHERE Filter:** + +Assuming the source table includes a ``score`` column, count high-scoring features per bin: + +.. code-block:: sql + + SELECT RASTERIZE(interval, 1000) AS depth + FROM features + WHERE score > 10 + +Supported FROM clauses +~~~~~~~~~~~~~~~~~~~~~~ + +``RASTERIZE`` requires a ``FROM`` clause that references a table or named CTE. Inline subqueries (``FROM (SELECT ...) AS sub``) and ``VALUES`` clauses are not supported — wrap the derivation in a ``WITH`` clause and select ``RASTERIZE(...)`` from the CTE by name: + +.. code-block:: sql + + -- Not supported: inline subquery in FROM + SELECT RASTERIZE(interval, 1000) + FROM (SELECT * FROM features WHERE score > 50) AS filtered + + -- Supported: same derivation wrapped in a CTE + WITH filtered AS ( + SELECT * FROM features WHERE score > 50 + ) + SELECT RASTERIZE(interval, 1000) FROM filtered + +Any ``WITH`` clauses you declare are preserved alongside the internal ``__giql_bins`` CTE in the transpiled SQL. + +Performance Notes +~~~~~~~~~~~~~~~~~ + +- The operator creates one bin per chromosome per step, so smaller resolutions produce more rows +- A ``LEFT JOIN`` ensures bins with zero coverage are included in the output +- For very large genomes, consider restricting the query with a ``WHERE`` clause on chromosome + +Related Operators +~~~~~~~~~~~~~~~~~ + +- :ref:`MERGE ` - Combine overlapping intervals into single regions +- :ref:`CLUSTER ` - Assign cluster IDs to overlapping intervals diff --git a/docs/dialect/index.rst b/docs/dialect/index.rst index 8d70e9d..ddd7f07 100644 --- a/docs/dialect/index.rst +++ b/docs/dialect/index.rst @@ -95,6 +95,9 @@ Combine and cluster genomic intervals. * - :ref:`MERGE ` - Combine overlapping intervals into unified regions - ``SELECT MERGE(interval) FROM features`` + * - :ref:`RASTERIZE ` + - Rasterize intervals onto a fixed bin grid with per-bin counts + - ``SELECT RASTERIZE(interval, 1000) FROM features`` See :doc:`aggregation-operators` for detailed documentation. diff --git a/docs/recipes/index.rst b/docs/recipes/index.rst index cc97e47..c4b65d6 100644 --- a/docs/recipes/index.rst +++ b/docs/recipes/index.rst @@ -19,6 +19,10 @@ Recipe Categories Clustering overlapping intervals, distance-based clustering, merging intervals, and aggregating cluster statistics. +:doc:`rasterize` + Rasterizing intervals onto a fixed bin grid: per-bin counts, + strand-specific counts, normalisation, and 5' end counting. + :doc:`advanced` Multi-range matching, complex filtering with joins, aggregate statistics, window expansions, and multi-table queries. diff --git a/docs/recipes/rasterize.rst b/docs/recipes/rasterize.rst new file mode 100644 index 0000000..d133874 --- /dev/null +++ b/docs/recipes/rasterize.rst @@ -0,0 +1,145 @@ +Rasterize +========= + +This section covers patterns for projecting interval data onto a fixed-resolution bin grid using GIQL's ``RASTERIZE`` operator. + +Basic Usage +----------- + +Rasterized counts underpin most genome-wide signal summaries — read-pileup plots for ChIP-seq, exon-level depth in RNA-seq, and peak-density overviews across megabases. The recipes below start from a canonical per-bin count and build toward more specialised variants. + +Count Overlapping Features +~~~~~~~~~~~~~~~~~~~~~~~~~~ + +Count the number of features overlapping each 1 kb bin across the genome: + +.. code-block:: sql + + SELECT RASTERIZE(interval, 1000) AS depth + FROM features + +**Sample output:** + +.. code-block:: text + + ┌────────┬────────┬────────┬───────┐ + │ chrom │ start │ end │ depth │ + ├────────┼────────┼────────┼───────┤ + │ chr1 │ 0 │ 1000 │ 3 │ + │ chr1 │ 1000 │ 2000 │ 1 │ + │ chr1 │ 2000 │ 3000 │ 0 │ + │ ... │ ... │ ... │ ... │ + └────────┴────────┴────────┴───────┘ + +Each row represents one genomic bin. Bins with no overlapping features appear with a count of zero. An interval that spans more than one bin is counted in each bin it overlaps (the ``bedtools coverage`` convention), so the sum of bin counts is generally greater than the number of source intervals. + +**Use case:** Compute read depth or feature density at a fixed resolution. + +Custom Bin Size +~~~~~~~~~~~~~~~ + +Use a finer resolution of 100 bp: + +.. code-block:: sql + + SELECT RASTERIZE(interval, 100) AS depth + FROM reads + +**Use case:** High-resolution count tracks for visualisation. + +Named Resolution Parameter +~~~~~~~~~~~~~~~~~~~~~~~~~~ + +The resolution can also be supplied by name: + +.. code-block:: sql + + SELECT RASTERIZE(interval, resolution := 500) AS depth + FROM features + +Both ``:=`` and ``=>`` are accepted for named parameters. + +.. note:: + + Weighted summary statistics (mean, sum, min, max over interval values, with bin-boundary-aware weighting) are not yet implemented. See the project tracker for the follow-up. + +Filtered Rasterization +---------------------- + +Strand-Specific Counts +~~~~~~~~~~~~~~~~~~~~~~ + +Compute per-bin counts for each strand separately by filtering: + +.. code-block:: sql + + -- Plus strand + SELECT RASTERIZE(interval, 1000) AS depth + FROM features + WHERE strand = '+' + +.. code-block:: sql + + -- Minus strand + SELECT RASTERIZE(interval, 1000) AS depth + FROM features + WHERE strand = '-' + +**Use case:** Strand-specific signal tracks for RNA-seq or stranded assays. + +High-Scoring Features +~~~~~~~~~~~~~~~~~~~~~ + +Restrict counts to features above a quality threshold: + +.. code-block:: sql + + SELECT RASTERIZE(interval, 1000) AS depth + FROM features + WHERE score > 10 + +5' End Counting +~~~~~~~~~~~~~~~ + +To count only the 5' ends of features (e.g. TSS or read starts), first +create a view or CTE that trims each interval to its 5' end, then apply +``RASTERIZE``: + +.. code-block:: sql + + WITH five_prime AS ( + SELECT chrom, "start", "start" + 1 AS "end" + FROM features + WHERE strand = '+' + UNION ALL + SELECT chrom, "end" - 1 AS "start", "end" + FROM features + WHERE strand = '-' + ) + SELECT RASTERIZE(interval, 1000) AS tss_count + FROM five_prime + +Normalised Counts +----------------- + +RPM Normalisation +~~~~~~~~~~~~~~~~~ + +Normalise bin counts to reads per million (RPM) by dividing by the total +number of reads: + +.. code-block:: sql + + WITH bins AS ( + SELECT RASTERIZE(interval, 1000) AS depth + FROM reads + ), + total AS ( + SELECT COUNT(*) AS n FROM reads + ) + SELECT + bins.chrom, + bins.start, + bins.end, + bins.depth * 1000000.0 / total.n AS rpm + FROM bins, total diff --git a/pyproject.toml b/pyproject.toml index 647358b..91ae1c5 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -57,6 +57,9 @@ path = "build-hooks/metadata.py" [tool.pytest.ini_options] addopts = "--cov --cov-config=.coveragerc" +markers = [ + "integration: tests exercising real bedtools subprocesses and DuckDB I/O", +] [tool.ruff] line-length = 89 diff --git a/src/giql/dialect.py b/src/giql/dialect.py index 6c70104..71dde2d 100644 --- a/src/giql/dialect.py +++ b/src/giql/dialect.py @@ -13,6 +13,7 @@ from giql.expressions import Contains from giql.expressions import GIQLCluster +from giql.expressions import GIQLRasterize from giql.expressions import GIQLDistance from giql.expressions import GIQLMerge from giql.expressions import GIQLNearest @@ -54,6 +55,7 @@ class Parser(Parser): FUNCTIONS = { **Parser.FUNCTIONS, "CLUSTER": GIQLCluster.from_arg_list, + "RASTERIZE": GIQLRasterize.from_arg_list, "MERGE": GIQLMerge.from_arg_list, "DISTANCE": GIQLDistance.from_arg_list, "NEAREST": GIQLNearest.from_arg_list, diff --git a/src/giql/expressions.py b/src/giql/expressions.py index 857a223..f5c6d30 100644 --- a/src/giql/expressions.py +++ b/src/giql/expressions.py @@ -142,6 +142,33 @@ def from_arg_list(cls, args): return cls(**kwargs) +class GIQLRasterize(exp.Func): + """RASTERIZE aggregate function that projects intervals onto a fixed bin grid. + + Tiles the genome into fixed-width bins and counts the number of + overlapping intervals per bin (bedtools-coverage convention: an + interval that spans multiple bins is counted in each of them). + + Examples: + RASTERIZE(interval, 1000) + RASTERIZE(interval, resolution := 1000) + """ + + arg_types = { + "this": True, # genomic column + "resolution": True, # bin width (positional or named) + } + + @classmethod + def from_arg_list(cls, args): + kwargs, positional_args = _split_named_and_positional(args) + if len(positional_args) > 0: + kwargs["this"] = positional_args[0] + if len(positional_args) > 1: + kwargs["resolution"] = positional_args[1] + return cls(**kwargs) + + class GIQLDistance(exp.Func): """DISTANCE function for calculating genomic distances between intervals. diff --git a/src/giql/transformer.py b/src/giql/transformer.py index ed0b3e1..1965f65 100644 --- a/src/giql/transformer.py +++ b/src/giql/transformer.py @@ -15,6 +15,7 @@ from giql.constants import DEFAULT_START_COL from giql.constants import DEFAULT_STRAND_COL from giql.expressions import GIQLCluster +from giql.expressions import GIQLRasterize from giql.expressions import GIQLMerge from giql.expressions import Intersects from giql.table import Tables @@ -50,7 +51,7 @@ def _get_table_name(self, query: exp.Select) -> str | None: :param query: Query to extract table name from - :return: + :returns: Table name if FROM contains a simple table, None otherwise """ from_clause = query.args.get("from_") @@ -67,7 +68,7 @@ def _get_genomic_columns(self, query: exp.Select) -> tuple[str, str, str, str]: :param query: Query to extract table and column info from - :return: + :returns: Tuple of (chrom_col, start_col, end_col, strand_col) """ table_name = self._get_table_name(query) @@ -94,7 +95,7 @@ def transform(self, query: exp.Expression) -> exp.Expression: :param query: Parsed query AST - :return: + :returns: Transformed query AST """ if not isinstance(query, exp.Select): @@ -151,7 +152,7 @@ def _find_cluster_expressions(self, query: exp.Select) -> list[GIQLCluster]: :param query: Query to search - :return: + :returns: List of CLUSTER expressions """ cluster_exprs = [] @@ -175,7 +176,7 @@ def _transform_for_cluster( Original query :param cluster_expr: CLUSTER expression to transform - :return: + :returns: Transformed query with CTEs """ # Extract CLUSTER parameters @@ -378,7 +379,7 @@ def transform(self, query: exp.Expression) -> exp.Expression: :param query: Parsed query AST - :return: + :returns: Transformed query AST """ if not isinstance(query, exp.Select): @@ -436,7 +437,7 @@ def _find_merge_expressions(self, query: exp.Select) -> list[GIQLMerge]: :param query: Query to search - :return: + :returns: List of MERGE expressions """ merge_exprs = [] @@ -459,7 +460,7 @@ def _transform_for_merge( Original query :param merge_expr: MERGE expression to transform - :return: + :returns: Transformed query with clustering and aggregation """ # Extract MERGE parameters (same as CLUSTER) @@ -577,6 +578,10 @@ def _transform_for_merge( exp.Ordered(this=exp.column(start_col, quoted=True)), append=True, copy=False ) + # Preserve any existing CTEs from the original query + if query.args.get("with_"): + final_query.set("with_", query.args["with_"].copy()) + return final_query @@ -1472,3 +1477,409 @@ def _build_join_back_joins( join3 = exp.Join(**join3_kwargs) return [join1, join2, join3] + + +class RasterizeTransformer: + """Transform queries containing RASTERIZE into binned count queries. + + RASTERIZE tiles the genome into fixed-width bins and counts overlapping + intervals per bin: + + SELECT RASTERIZE(interval, 1000) FROM features + + Into: + + WITH __giql_bins AS ( + SELECT chrom, bin_start AS start, bin_start + 1000 AS "end" + FROM ( + SELECT DISTINCT chrom, MAX("end") AS __max_end + FROM features GROUP BY chrom + ) AS __giql_chroms, + LATERAL generate_series(0, __max_end, 1000) AS t(bin_start) + ) + SELECT bins.chrom, bins.start, bins."end", COUNT(source.*) + FROM __giql_bins AS bins + LEFT JOIN features AS source + ON source.start < bins."end" + AND source."end" > bins.start + AND source.chrom = bins.chrom + GROUP BY bins.chrom, bins.start, bins."end" + ORDER BY bins.chrom, bins.start + """ + + def __init__(self, tables: Tables): + """Initialize transformer. + + :param tables: + Table configurations for column mapping + """ + self.tables = tables + self.cluster_transformer = ClusterTransformer(tables) + + def transform(self, query: exp.Expression) -> exp.Expression: + """Transform query if it contains RASTERIZE expressions. + + :param query: + Parsed query AST + :returns: + Transformed query AST + """ + if not isinstance(query, exp.Select): + return query + + # Recursively transform CTEs + if query.args.get("with_"): + cte = query.args["with_"] + for cte_expr in cte.expressions: + if isinstance(cte_expr, exp.CTE): + cte_expr.set("this", self.transform(cte_expr.this)) + + # Recursively transform subqueries in FROM/JOIN/WHERE + for key in ("from_", "where"): + if query.args.get(key): + self._transform_subqueries_in_node(query.args[key]) + if query.args.get("joins"): + for join in query.args["joins"]: + self._transform_subqueries_in_node(join) + + # Find RASTERIZE expressions in SELECT + rasterize_exprs = self._find_rasterize_expressions(query) + if not rasterize_exprs: + return query + + if len(rasterize_exprs) > 1: + raise ValueError("Multiple RASTERIZE expressions not yet supported") + + return self._transform_for_rasterize(query, rasterize_exprs[0]) + + def _get_table_alias(self, query: exp.Select) -> str | None: + """Extract table alias from query's FROM clause. + + :param query: + Query to extract alias from + :returns: + Table alias if present, None otherwise + """ + from_clause = query.args.get("from_") + if not from_clause: + return None + if isinstance(from_clause.this, exp.Table): + return from_clause.this.alias + return None + + def _transform_subqueries_in_node(self, node: exp.Expression): + """Recursively transform subqueries within an expression node. + + :param node: + Expression node to search for subqueries + """ + for subquery in node.find_all(exp.Subquery): + if isinstance(subquery.this, exp.Select): + transformed = self.transform(subquery.this) + subquery.set("this", transformed) + + def _find_rasterize_expressions(self, query: exp.Select) -> list[GIQLRasterize]: + """Find all RASTERIZE expressions in query. + + :param query: + Query to search + :returns: + List of RASTERIZE expressions + """ + rasterize_exprs = [] + for expression in query.expressions: + if isinstance(expression, GIQLRasterize): + rasterize_exprs.append(expression) + elif isinstance(expression, exp.Alias): + if isinstance(expression.this, GIQLRasterize): + rasterize_exprs.append(expression.this) + return rasterize_exprs + + def _transform_for_rasterize( + self, query: exp.Select, rasterize_expr: GIQLRasterize + ) -> exp.Select: + """Transform query to compute RASTERIZE using bins CTE + JOIN + GROUP BY. + + :param query: + Original query + :param rasterize_expr: + RASTERIZE expression to transform + :returns: + Transformed query + """ + # Extract parameters + resolution_expr = rasterize_expr.args.get("resolution") + if isinstance(resolution_expr, exp.Literal): + resolution = int(resolution_expr.this) + elif ( + isinstance(resolution_expr, exp.Neg) + and isinstance(resolution_expr.this, exp.Literal) + ): + resolution = -int(resolution_expr.this.this) + else: + raise ValueError("RASTERIZE resolution must be an integer literal") + + if resolution <= 0: + raise ValueError( + f"RASTERIZE resolution must be positive, got {resolution}" + ) + + # Get column names and table info + chrom_col, start_col, end_col, _ = ( + self.cluster_transformer._get_genomic_columns(query) + ) + table_name = self.cluster_transformer._get_table_name(query) + if not table_name: + raise ValueError( + "RASTERIZE requires a FROM clause that references a table " + "or CTE by name. Inline subqueries and VALUES clauses in " + "FROM are not yet supported — wrap the derivation in a " + "WITH clause (CTE) and select RASTERIZE(...) from the CTE " + "by name instead." + ) + table_alias = self._get_table_alias(query) + source_ref = table_alias or table_name or "source" + + # Build __giql_chroms subquery: + # SELECT DISTINCT chrom, MAX("end") AS __max_end FROM GROUP BY chrom + chroms_select = exp.Select() + chroms_select.select( + exp.column(chrom_col, quoted=True), + copy=False, + ) + chroms_select.select( + exp.alias_( + exp.Max(this=exp.column(end_col, quoted=True)), + "__max_end", + quoted=False, + ), + append=True, + copy=False, + ) + + if table_name: + if table_alias: + chroms_select.from_( + exp.alias_(exp.to_table(table_name), table_alias, table=True), + copy=False, + ) + else: + chroms_select.from_(exp.to_table(table_name), copy=False) + + # Apply WHERE from original query to the chroms subquery too, + # qualifying unqualified column references with the table name + if query.args.get("where"): + chroms_where = query.args["where"].copy() + if table_name: + for col in chroms_where.find_all(exp.Column): + if not col.table: + col.set("table", exp.Identifier(this=table_name)) + chroms_select.set("where", chroms_where) + + chroms_select.group_by(exp.column(chrom_col, quoted=True), copy=False) + + chroms_subquery = exp.Subquery( + this=chroms_select, + alias=exp.TableAlias(this=exp.Identifier(this="__giql_chroms")), + ) + + # Build bins CTE using raw SQL for generate_series + LATERAL + # since SQLGlot doesn't natively support generate_series + bins_select = exp.Select() + bins_select.select( + exp.column(chrom_col, table="__giql_chroms", quoted=True), + copy=False, + ) + bins_select.select( + exp.alias_( + exp.column("bin_start"), + start_col, + quoted=True, + ), + append=True, + copy=False, + ) + bins_select.select( + exp.alias_( + exp.Add( + this=exp.column("bin_start"), + expression=exp.Literal.number(resolution), + ), + end_col, + quoted=True, + ), + append=True, + copy=False, + ) + + # FROM __giql_chroms subquery + bins_select.from_(chroms_subquery, copy=False) + + # CROSS JOIN LATERAL generate_series(0, __max_end - 1, resolution) + # AS t(bin_start) — upper bound subtracts 1 because generate_series + # is endpoint-inclusive and we don't want a trailing empty bin when + # MAX(end) lands exactly on a bin boundary. + lateral_join = exp.Join( + this=exp.Lateral( + this=exp.Anonymous( + this="generate_series", + expressions=[ + exp.Literal.number(0), + exp.Sub( + this=exp.column("__max_end"), + expression=exp.Literal.number(1), + ), + exp.Literal.number(resolution), + ], + ), + alias=exp.TableAlias( + this=exp.Identifier(this="t"), + columns=[exp.Identifier(this="bin_start")], + ), + ), + kind="CROSS", + ) + bins_select.append("joins", lateral_join) + + # Wrap bins_select as a CTE named __giql_bins + bins_cte = exp.CTE( + this=bins_select, + alias=exp.TableAlias(this=exp.Identifier(this="__giql_bins")), + ) + with_clause = exp.With(expressions=[bins_cte]) + + # COUNT(chrom) — null-safe count of intervals overlapping the bin. + # Counting a non-null source column gives 0 for empty bins (LEFT JOIN + # produces NULLs for non-matches, which COUNT excludes). + agg_expr = exp.Count( + this=exp.column(chrom_col, table=source_ref, quoted=True), + ) + + # Build main SELECT + final_query = exp.Select() + + # Add bin coordinate columns + final_query.select( + exp.column(chrom_col, table="bins", quoted=True), + copy=False, + ) + final_query.select( + exp.column(start_col, table="bins", quoted=True), + append=True, + copy=False, + ) + final_query.select( + exp.column(end_col, table="bins", quoted=True), + append=True, + copy=False, + ) + + # Replace RASTERIZE(...) in select list with aggregate, and add other columns + for expression in query.expressions: + if isinstance(expression, GIQLRasterize): + final_query.select( + exp.alias_(agg_expr, "value", quoted=False), + append=True, + copy=False, + ) + elif isinstance(expression, exp.Alias) and isinstance( + expression.this, GIQLRasterize + ): + final_query.select( + exp.alias_(agg_expr, expression.alias, quoted=False), + append=True, + copy=False, + ) + else: + final_query.select(expression, append=True, copy=False) + + # FROM __giql_bins AS bins + final_query.from_( + exp.Table( + this=exp.Identifier(this="__giql_bins"), + alias=exp.TableAlias(this=exp.Identifier(this="bins")), + ), + copy=False, + ) + + # LEFT JOIN source ON overlap conditions + source_table = exp.to_table(table_name) + if table_alias: + source_table.set( + "alias", exp.TableAlias(this=exp.Identifier(this=source_ref)) + ) + + join_condition = exp.And( + this=exp.And( + this=exp.LT( + this=exp.column(start_col, table=source_ref, quoted=True), + expression=exp.column(end_col, table="bins", quoted=True), + ), + expression=exp.GT( + this=exp.column(end_col, table=source_ref, quoted=True), + expression=exp.column(start_col, table="bins", quoted=True), + ), + ), + expression=exp.EQ( + this=exp.column(chrom_col, table=source_ref, quoted=True), + expression=exp.column(chrom_col, table="bins", quoted=True), + ), + ) + + # Merge original WHERE into the JOIN ON condition so that + # LEFT JOIN still produces zero-coverage bins (WHERE would filter + # them out because source columns are NULL for non-matching bins) + if query.args.get("where"): + where_condition = query.args["where"].this.copy() + # Qualify unqualified column references with source_ref + for col in where_condition.find_all(exp.Column): + if not col.table: + col.set("table", exp.Identifier(this=source_ref)) + join_condition = exp.And( + this=join_condition, + expression=where_condition, + ) + + left_join = exp.Join( + this=source_table, + on=join_condition, + kind="LEFT", + ) + final_query.append("joins", left_join) + + # GROUP BY bins.chrom, bins.start, bins.end + final_query.group_by( + exp.column(chrom_col, table="bins", quoted=True), + copy=False, + ) + final_query.group_by( + exp.column(start_col, table="bins", quoted=True), + append=True, + copy=False, + ) + final_query.group_by( + exp.column(end_col, table="bins", quoted=True), + append=True, + copy=False, + ) + + # ORDER BY bins.chrom, bins.start + final_query.order_by( + exp.Ordered(this=exp.column(chrom_col, table="bins", quoted=True)), + copy=False, + ) + final_query.order_by( + exp.Ordered(this=exp.column(start_col, table="bins", quoted=True)), + append=True, + copy=False, + ) + + # Attach the WITH clause, preserving any user CTEs from the input query + existing_with = query.args.get("with_") + if existing_with: + merged_ctes = [cte.copy() for cte in existing_with.expressions] + [bins_cte] + final_query.set("with_", exp.With(expressions=merged_ctes)) + else: + final_query.set("with_", with_clause) + + return final_query diff --git a/src/giql/transpile.py b/src/giql/transpile.py index 7c70746..d01a6aa 100644 --- a/src/giql/transpile.py +++ b/src/giql/transpile.py @@ -11,6 +11,7 @@ from giql.table import Table from giql.table import Tables from giql.transformer import ClusterTransformer +from giql.transformer import RasterizeTransformer from giql.transformer import IntersectsBinnedJoinTransformer from giql.transformer import MergeTransformer @@ -120,6 +121,7 @@ def transpile( tables_container, bin_size=intersects_bin_size, ) + rasterize_transformer = RasterizeTransformer(tables_container) merge_transformer = MergeTransformer(tables_container) cluster_transformer = ClusterTransformer(tables_container) @@ -135,6 +137,8 @@ def transpile( # Apply transformations try: ast = intersects_transformer.transform(ast) + # RASTERIZE transformation (independent) + ast = rasterize_transformer.transform(ast) # MERGE transformation (which may internally use CLUSTER) ast = merge_transformer.transform(ast) # CLUSTER transformation for any standalone CLUSTER expressions diff --git a/tests/integration/bedtools/conftest.py b/tests/integration/bedtools/conftest.py index 79994a1..ae402b5 100644 --- a/tests/integration/bedtools/conftest.py +++ b/tests/integration/bedtools/conftest.py @@ -15,8 +15,6 @@ allow_module_level=True, ) -pytestmark = pytest.mark.integration - from .utils.duckdb_loader import load_intervals # noqa: E402 diff --git a/tests/integration/bedtools/test_cluster.py b/tests/integration/bedtools/test_cluster.py index c492f0d..364caf6 100644 --- a/tests/integration/bedtools/test_cluster.py +++ b/tests/integration/bedtools/test_cluster.py @@ -6,12 +6,16 @@ number of distinct clusters should equal the number of merged intervals. """ +import pytest + from giql import transpile from .utils.bedtools_wrapper import merge from .utils.data_models import GenomicInterval from .utils.duckdb_loader import load_intervals +pytestmark = pytest.mark.integration + def test_cluster_basic(duckdb_connection): """ diff --git a/tests/integration/bedtools/test_contains.py b/tests/integration/bedtools/test_contains.py index 6325e43..87fe584 100644 --- a/tests/integration/bedtools/test_contains.py +++ b/tests/integration/bedtools/test_contains.py @@ -5,8 +5,12 @@ equivalent exists, so tests validate against known expected results. """ +import pytest + from .utils.data_models import GenomicInterval +pytestmark = pytest.mark.integration + def test_contains_point(giql_query): """ diff --git a/tests/integration/bedtools/test_correctness_intersect.py b/tests/integration/bedtools/test_correctness_intersect.py new file mode 100644 index 0000000..dedbbfa --- /dev/null +++ b/tests/integration/bedtools/test_correctness_intersect.py @@ -0,0 +1,307 @@ +"""Extended correctness tests for GIQL INTERSECTS operator vs bedtools intersect. + +These tests cover boundary cases, scale, and edge scenarios beyond the basic +tests in test_intersect.py, ensuring comprehensive GIQL/bedtools equivalence. +""" + +import pytest + +from giql import transpile + +from .utils.bedtools_wrapper import intersect +from .utils.comparison import compare_results +from .utils.data_models import GenomicInterval +from .utils.duckdb_loader import load_intervals +from .utils.random_intervals import generate_random_intervals + +pytestmark = pytest.mark.integration + + +def _run_intersect_comparison( + duckdb_connection, + intervals_a, + intervals_b, + strand_filter="", +): + """Run GIQL INTERSECTS and bedtools intersect, return ComparisonResult.""" + load_intervals( + duckdb_connection, + "intervals_a", + [i.to_tuple() for i in intervals_a], + ) + load_intervals( + duckdb_connection, + "intervals_b", + [i.to_tuple() for i in intervals_b], + ) + + strand_mode = None + if "a.strand = b.strand" in strand_filter: + strand_mode = "same" + elif "a.strand != b.strand" in strand_filter: + strand_mode = "opposite" + + bedtools_result = intersect( + [i.to_tuple() for i in intervals_a], + [i.to_tuple() for i in intervals_b], + strand_mode=strand_mode, + ) + + where_clause = "WHERE a.interval INTERSECTS b.interval" + if strand_filter: + where_clause += f" AND {strand_filter}" + + sql = transpile( + f""" + SELECT DISTINCT a.* + FROM intervals_a a, intervals_b b + {where_clause} + """, + tables=["intervals_a", "intervals_b"], + ) + giql_result = duckdb_connection.execute(sql).fetchall() + + return compare_results(giql_result, bedtools_result) + + +def test_intersects_should_match_bedtools_when_overlap_is_one_bp(duckdb_connection): + """Test INTERSECTS matches bedtools for a minimal 1bp overlap. + + Given: + Two intervals that overlap by exactly one base pair + When: + GIQL INTERSECTS is compared to bedtools intersect + Then: + It should detect the 1bp overlap identically to bedtools + """ + # Arrange + a = [GenomicInterval("chr1", 100, 200, "a1", 0, "+")] + b = [GenomicInterval("chr1", 199, 300, "b1", 0, "+")] + + # Act + comparison = _run_intersect_comparison(duckdb_connection, a, b) + + # Assert + assert comparison.match, comparison.failure_message() + assert comparison.giql_row_count == 1 + + +def test_intersects_should_match_bedtools_when_a_contains_b(duckdb_connection): + """Test INTERSECTS matches bedtools when A fully contains B. + + Given: + Interval A that fully contains interval B + When: + GIQL INTERSECTS is compared to bedtools intersect + Then: + It should report A as intersecting B + """ + # Arrange + a = [GenomicInterval("chr1", 100, 500, "a1", 0, "+")] + b = [GenomicInterval("chr1", 200, 300, "b1", 0, "+")] + + # Act + comparison = _run_intersect_comparison(duckdb_connection, a, b) + + # Assert + assert comparison.match, comparison.failure_message() + assert comparison.giql_row_count == 1 + + +def test_intersects_should_match_bedtools_when_b_contains_a(duckdb_connection): + """Test INTERSECTS matches bedtools when B fully contains A. + + Given: + Interval B that fully contains interval A + When: + GIQL INTERSECTS is compared to bedtools intersect + Then: + It should report A as intersecting B + """ + # Arrange + a = [GenomicInterval("chr1", 200, 300, "a1", 0, "+")] + b = [GenomicInterval("chr1", 100, 500, "b1", 0, "+")] + + # Act + comparison = _run_intersect_comparison(duckdb_connection, a, b) + + # Assert + assert comparison.match, comparison.failure_message() + assert comparison.giql_row_count == 1 + + +def test_intersects_should_deduplicate_when_a_overlaps_multiple_b(duckdb_connection): + """Test INTERSECTS with DISTINCT matches bedtools -u deduplication. + + Given: + One interval in A that overlaps several intervals in B + When: + GIQL INTERSECTS with DISTINCT is compared to bedtools intersect -u + Then: + It should report the A interval exactly once + """ + # Arrange + a = [GenomicInterval("chr1", 100, 300, "a1", 0, "+")] + b = [ + GenomicInterval("chr1", 150, 200, "b1", 0, "+"), + GenomicInterval("chr1", 200, 250, "b2", 0, "+"), + GenomicInterval("chr1", 250, 350, "b3", 0, "+"), + ] + + # Act + comparison = _run_intersect_comparison(duckdb_connection, a, b) + + # Assert + assert comparison.match, comparison.failure_message() + assert comparison.giql_row_count == 1 + + +def test_intersects_should_match_bedtools_when_chromosome_names_are_non_standard( + duckdb_connection, +): + """Test INTERSECTS matches bedtools on non-standard chromosome names. + + Given: + Intervals on non-standard chromosome names like chrM and chrUn + When: + GIQL INTERSECTS is compared to bedtools intersect + Then: + It should match bedtools regardless of chromosome naming + """ + # Arrange + a = [ + GenomicInterval("chrM", 100, 200, "a1", 0, "+"), + GenomicInterval("chrUn", 100, 200, "a2", 0, "+"), + ] + b = [ + GenomicInterval("chrM", 150, 250, "b1", 0, "+"), + GenomicInterval("chrUn", 150, 250, "b2", 0, "+"), + ] + + # Act + comparison = _run_intersect_comparison(duckdb_connection, a, b) + + # Assert + assert comparison.match, comparison.failure_message() + assert comparison.giql_row_count == 2 + + +def test_intersects_should_match_bedtools_when_intervals_are_very_large( + duckdb_connection, +): + """Test INTERSECTS matches bedtools for multi-megabase intervals. + + Given: + Very large genomic intervals spanning millions of bases + When: + GIQL INTERSECTS is compared to bedtools intersect + Then: + It should produce the same overlap result as bedtools + """ + # Arrange + a = [GenomicInterval("chr1", 0, 10_000_000, "a1", 0, "+")] + b = [GenomicInterval("chr1", 5_000_000, 15_000_000, "b1", 0, "+")] + + # Act + comparison = _run_intersect_comparison(duckdb_connection, a, b) + + # Assert + assert comparison.match, comparison.failure_message() + assert comparison.giql_row_count == 1 + + +def test_intersects_should_match_bedtools_at_scale(duckdb_connection): + """Test INTERSECTS matches bedtools on a larger generated dataset. + + Given: + A generated dataset with 100 intervals per chromosome on 3 chromosomes + When: + GIQL INTERSECTS is compared to bedtools intersect + Then: + It should match bedtools on the full dataset + """ + # Arrange + intervals_a = generate_random_intervals( + seed=42, + prefix="a", + count_per_chrom=100, + n_chroms=3, + start_max=900_000, + ) + intervals_b = generate_random_intervals( + seed=43, + prefix="b", + count_per_chrom=100, + n_chroms=3, + start_max=900_000, + ) + + # Act + comparison = _run_intersect_comparison(duckdb_connection, intervals_a, intervals_b) + + # Assert + assert comparison.match, comparison.failure_message() + + +def test_intersects_should_match_bedtools_when_same_strand_filter_applied( + duckdb_connection, +): + """Test INTERSECTS with same-strand filter matches bedtools -s. + + Given: + Overlapping intervals with mixed strand orientations + When: + GIQL INTERSECTS with a same-strand filter is compared to bedtools -s + Then: + It should return only the same-strand overlaps + """ + # Arrange + a = [ + GenomicInterval("chr1", 100, 200, "a_plus", 0, "+"), + GenomicInterval("chr1", 100, 200, "a_minus", 0, "-"), + ] + b = [GenomicInterval("chr1", 150, 250, "b_plus", 0, "+")] + + # Act + comparison = _run_intersect_comparison( + duckdb_connection, + a, + b, + strand_filter="a.strand = b.strand", + ) + + # Assert + assert comparison.match, comparison.failure_message() + assert comparison.giql_row_count == 1 + + +def test_intersects_should_match_bedtools_when_opposite_strand_filter_applied( + duckdb_connection, +): + """Test INTERSECTS with opposite-strand filter matches bedtools -S. + + Given: + Overlapping intervals with mixed strand orientations + When: + GIQL INTERSECTS with an opposite-strand filter is compared to bedtools -S + Then: + It should return only the opposite-strand overlaps + """ + # Arrange + a = [ + GenomicInterval("chr1", 100, 200, "a_plus", 0, "+"), + GenomicInterval("chr1", 100, 200, "a_minus", 0, "-"), + ] + b = [GenomicInterval("chr1", 150, 250, "b_plus", 0, "+")] + + # Act + comparison = _run_intersect_comparison( + duckdb_connection, + a, + b, + strand_filter="a.strand != b.strand", + ) + + # Assert + assert comparison.match, comparison.failure_message() + assert comparison.giql_row_count == 1 diff --git a/tests/integration/bedtools/test_correctness_merge.py b/tests/integration/bedtools/test_correctness_merge.py new file mode 100644 index 0000000..f30634e --- /dev/null +++ b/tests/integration/bedtools/test_correctness_merge.py @@ -0,0 +1,295 @@ +"""Extended correctness tests for GIQL MERGE operator vs bedtools merge. + +These tests cover transitive chains, topology variations, and scale scenarios +to ensure comprehensive GIQL/bedtools equivalence for merge operations. +""" + +import pytest + +from giql import transpile + +from .utils.bedtools_wrapper import merge +from .utils.comparison import compare_results +from .utils.data_models import GenomicInterval +from .utils.duckdb_loader import load_intervals +from .utils.random_intervals import generate_random_intervals + +pytestmark = pytest.mark.integration + + +def _run_merge_comparison(duckdb_connection, intervals, strand_mode=None): + """Run GIQL MERGE and bedtools merge, return ComparisonResult.""" + load_intervals( + duckdb_connection, + "intervals", + [i.to_tuple() for i in intervals], + ) + + bedtools_result = merge( + [i.to_tuple() for i in intervals], + strand_mode=strand_mode, + ) + + if strand_mode == "same": + giql_sql = "SELECT MERGE(interval, stranded := true) FROM intervals" + else: + giql_sql = "SELECT MERGE(interval) FROM intervals" + + sql = transpile(giql_sql, tables=["intervals"]) + giql_result = duckdb_connection.execute(sql).fetchall() + + return compare_results(giql_result, bedtools_result) + + +def test_merge_should_combine_transitive_chain_into_single_interval(duckdb_connection): + """Test MERGE collapses a transitive overlap chain. + + Given: + A chain A overlaps B, B overlaps C (but A does not overlap C + directly) + When: + GIQL MERGE is compared to bedtools merge + Then: + It should merge the entire chain into a single interval + """ + # Arrange + intervals = [ + GenomicInterval("chr1", 100, 200, "i1", 0, "+"), + GenomicInterval("chr1", 180, 300, "i2", 0, "+"), + GenomicInterval("chr1", 280, 400, "i3", 0, "+"), + ] + + # Act + comparison = _run_merge_comparison(duckdb_connection, intervals) + + # Assert + assert comparison.match, comparison.failure_message() + assert comparison.giql_row_count == 1 + + +def test_merge_should_return_interval_unchanged_when_input_is_single_interval( + duckdb_connection, +): + """Test MERGE is a no-op for a single-interval input. + + Given: + A single interval + When: + GIQL MERGE is compared to bedtools merge + Then: + It should return the single interval unchanged + """ + # Arrange + intervals = [GenomicInterval("chr1", 100, 200, "i1", 0, "+")] + + # Act + comparison = _run_merge_comparison(duckdb_connection, intervals) + + # Assert + assert comparison.match, comparison.failure_message() + assert comparison.giql_row_count == 1 + + +def test_merge_should_produce_one_region_when_all_intervals_overlap( + duckdb_connection, +): + """Test MERGE collapses fully overlapping intervals into one region. + + Given: + All intervals on a chromosome overlap forming one big region + When: + GIQL MERGE is compared to bedtools merge + Then: + It should return a single merged interval + """ + # Arrange + intervals = [ + GenomicInterval("chr1", 100, 500, "i1", 0, "+"), + GenomicInterval("chr1", 200, 400, "i2", 0, "+"), + GenomicInterval("chr1", 300, 600, "i3", 0, "+"), + GenomicInterval("chr1", 150, 550, "i4", 0, "+"), + ] + + # Act + comparison = _run_merge_comparison(duckdb_connection, intervals) + + # Assert + assert comparison.match, comparison.failure_message() + assert comparison.giql_row_count == 1 + + +def test_merge_should_return_correct_region_count_when_topology_is_mixed( + duckdb_connection, +): + """Test MERGE handles a mix of overlapping clusters and isolated intervals. + + Given: + A mix of overlapping clusters and isolated intervals + When: + GIQL MERGE is compared to bedtools merge + Then: + It should produce the correct number of merged regions + """ + # Arrange + intervals = [ + # Cluster 1: overlapping + GenomicInterval("chr1", 100, 200, "c1a", 0, "+"), + GenomicInterval("chr1", 150, 300, "c1b", 0, "+"), + # Isolated + GenomicInterval("chr1", 500, 600, "iso", 0, "+"), + # Cluster 2: overlapping + GenomicInterval("chr1", 800, 900, "c2a", 0, "+"), + GenomicInterval("chr1", 850, 1000, "c2b", 0, "+"), + ] + + # Act + comparison = _run_merge_comparison(duckdb_connection, intervals) + + # Assert + assert comparison.match, comparison.failure_message() + assert comparison.giql_row_count == 3 + + +def test_merge_should_combine_intervals_when_overlap_is_one_base(duckdb_connection): + """Test MERGE triggers on a single-base overlap. + + Given: + Intervals with exactly 1bp overlap + When: + GIQL MERGE is compared to bedtools merge + Then: + It should treat the 1bp overlap as sufficient to merge + """ + # Arrange + intervals = [ + GenomicInterval("chr1", 100, 200, "i1", 0, "+"), + GenomicInterval("chr1", 199, 300, "i2", 0, "+"), + ] + + # Act + comparison = _run_merge_comparison(duckdb_connection, intervals) + + # Assert + assert comparison.match, comparison.failure_message() + assert comparison.giql_row_count == 1 + + +def test_merge_should_match_bedtools_when_input_is_unsorted(duckdb_connection): + """Test MERGE is insensitive to input ordering. + + Given: + Intervals inserted in non-sorted order + When: + GIQL MERGE is compared to bedtools merge + Then: + It should produce the same results regardless of input order + """ + # Arrange + intervals = [ + GenomicInterval("chr1", 400, 500, "i3", 0, "+"), + GenomicInterval("chr1", 100, 200, "i1", 0, "+"), + GenomicInterval("chr1", 150, 250, "i2", 0, "+"), + ] + + # Act + comparison = _run_merge_comparison(duckdb_connection, intervals) + + # Assert + assert comparison.match, comparison.failure_message() + + +def test_merge_should_operate_per_chromosome_when_input_spans_multiple_chromosomes( + duckdb_connection, +): + """Test MERGE groups merges per chromosome. + + Given: + Overlapping intervals on separate chromosomes + When: + GIQL MERGE is compared to bedtools merge + Then: + It should merge per-chromosome independently + """ + # Arrange + intervals = [ + GenomicInterval("chr1", 100, 200, "c1a", 0, "+"), + GenomicInterval("chr1", 150, 300, "c1b", 0, "+"), + GenomicInterval("chr2", 100, 200, "c2a", 0, "+"), + GenomicInterval("chr2", 150, 300, "c2b", 0, "+"), + GenomicInterval("chr3", 100, 200, "c3", 0, "+"), # no overlap + ] + + # Act + comparison = _run_merge_comparison(duckdb_connection, intervals) + + # Assert + assert comparison.match, comparison.failure_message() + assert comparison.giql_row_count == 3 # 1 per chrom + + +def test_merge_should_preserve_strand_when_stranded_true(duckdb_connection): + """Test MERGE with stranded=true matches bedtools merge -s. + + Given: + Overlapping intervals on different strands + When: + GIQL MERGE(stranded=true) is compared to bedtools merge -s + Then: + It should produce the same per-strand merge count as bedtools + """ + # Arrange + intervals = [ + GenomicInterval("chr1", 100, 200, "i1", 0, "+"), + GenomicInterval("chr1", 150, 250, "i2", 0, "+"), + GenomicInterval("chr1", 120, 220, "i3", 0, "-"), + GenomicInterval("chr1", 180, 280, "i4", 0, "-"), + ] + load_intervals( + duckdb_connection, + "intervals", + [i.to_tuple() for i in intervals], + ) + + bedtools_result = merge( + [i.to_tuple() for i in intervals], + strand_mode="same", + ) + + sql = transpile( + "SELECT MERGE(interval, stranded := true) FROM intervals", + tables=["intervals"], + ) + + # Act + giql_result = duckdb_connection.execute(sql).fetchall() + + # Assert + # Both should have 2 merged intervals (one per strand) + assert len(giql_result) == len(bedtools_result) + + +def test_merge_should_match_bedtools_when_dataset_is_large(duckdb_connection): + """Test MERGE agrees with bedtools on a large synthetic dataset. + + Given: + 100+ intervals across 3 chromosomes + When: + GIQL MERGE is compared to bedtools merge + Then: + It should produce results matching bedtools on the full dataset + """ + # Arrange + intervals = generate_random_intervals( + seed=42, + prefix="chr", + count_per_chrom=100, + n_chroms=3, + start_max=500_000, + max_size=2000, + ) + + # Act + comparison = _run_merge_comparison(duckdb_connection, intervals) + + # Assert + assert comparison.match, comparison.failure_message() diff --git a/tests/integration/bedtools/test_correctness_nearest.py b/tests/integration/bedtools/test_correctness_nearest.py new file mode 100644 index 0000000..332ad32 --- /dev/null +++ b/tests/integration/bedtools/test_correctness_nearest.py @@ -0,0 +1,381 @@ +"""Extended correctness tests for GIQL NEAREST operator vs bedtools closest. + +These tests cover distance calculations, multi-query scenarios, and scale +to ensure comprehensive GIQL/bedtools equivalence for nearest operations. +""" + +import pytest + +from giql import transpile + +from .utils.bedtools_wrapper import closest +from .utils.data_models import GenomicInterval +from .utils.duckdb_loader import load_intervals +from .utils.random_intervals import generate_random_intervals + +pytestmark = pytest.mark.integration + + +def _load_and_query_nearest( + duckdb_connection, + intervals_a, + intervals_b, + *, + k=1, + stranded=False, +): + """Load intervals, run GIQL NEAREST and bedtools closest, return both results.""" + load_intervals( + duckdb_connection, + "intervals_a", + [i.to_tuple() for i in intervals_a], + ) + load_intervals( + duckdb_connection, + "intervals_b", + [i.to_tuple() for i in intervals_b], + ) + + strand_mode = "same" if stranded else None + bedtools_result = closest( + [i.to_tuple() for i in intervals_a], + [i.to_tuple() for i in intervals_b], + strand_mode=strand_mode, + k=k, + ) + + stranded_arg = ", stranded := true" if stranded else "" + sql = transpile( + f""" + SELECT a.*, b.* + FROM intervals_a a + CROSS JOIN LATERAL NEAREST( + intervals_b, + reference := a.interval, + k := {k}{stranded_arg} + ) b + ORDER BY a.chrom, a.start + """, + tables=["intervals_a", "intervals_b"], + ) + giql_result = duckdb_connection.execute(sql).fetchall() + + return giql_result, bedtools_result + + +def test_nearest_should_report_distance_zero_when_intervals_overlap(duckdb_connection): + """Test NEAREST reports zero distance for overlapping intervals. + + Given: + Overlapping intervals in A and B + When: + GIQL NEAREST is compared to bedtools closest + Then: + It should report distance=0 for the overlapping pair + """ + # Arrange + a = [GenomicInterval("chr1", 100, 300, "a1", 0, "+")] + b = [GenomicInterval("chr1", 200, 400, "b1", 0, "+")] + + # Act + giql_result, bedtools_result = _load_and_query_nearest(duckdb_connection, a, b) + + # Assert + assert len(giql_result) == len(bedtools_result) == 1 + # bedtools closest -d reports 0 for overlapping + assert bedtools_result[0][-1] == 0 + + +def test_nearest_should_find_adjacent_neighbor_when_intervals_touch( + duckdb_connection, +): + """Test NEAREST matches bedtools for adjacent non-overlapping intervals. + + Given: + Two adjacent intervals in half-open coordinates (a1 ending at + 200, b1 starting at 200 — touching but not overlapping) + When: + GIQL NEAREST is compared to bedtools closest + Then: + It should identify b1 as a1's nearest neighbor, and bedtools + should report the canonical adjacent-interval distance of 1 + (bedtools >= 2.31 counts the gap base in half-open coords) + """ + # Arrange + a = [GenomicInterval("chr1", 100, 200, "a1", 0, "+")] + b = [GenomicInterval("chr1", 200, 300, "b1", 0, "+")] + + # Act + giql_result, bedtools_result = _load_and_query_nearest(duckdb_connection, a, b) + + # Assert + assert len(giql_result) == len(bedtools_result) == 1 + assert bedtools_result[0][-1] == 1 + assert giql_result[0][9] == "b1" + + +def test_nearest_should_match_bedtools_when_candidate_is_upstream(duckdb_connection): + """Test NEAREST matches bedtools for an upstream candidate interval. + + Given: + A B interval positioned far upstream of the A interval + When: + GIQL NEAREST is compared to bedtools closest + Then: + It should identify the upstream candidate with the correct distance + """ + # Arrange + a = [GenomicInterval("chr1", 500, 600, "a1", 0, "+")] + b = [GenomicInterval("chr1", 100, 200, "b1", 0, "+")] + + # Act + giql_result, bedtools_result = _load_and_query_nearest(duckdb_connection, a, b) + + # Assert + assert len(giql_result) == len(bedtools_result) == 1 + # Distance: 500 - 200 = 300 (half-open), bedtools may report 301 + assert bedtools_result[0][-1] in (300, 301) + assert giql_result[0][9] == "b1" + + +def test_nearest_should_match_bedtools_when_candidate_is_downstream(duckdb_connection): + """Test NEAREST matches bedtools for a downstream candidate interval. + + Given: + A B interval positioned far downstream of the A interval + When: + GIQL NEAREST is compared to bedtools closest + Then: + It should identify the downstream candidate with the correct distance + """ + # Arrange + a = [GenomicInterval("chr1", 100, 200, "a1", 0, "+")] + b = [GenomicInterval("chr1", 500, 600, "b1", 0, "+")] + + # Act + giql_result, bedtools_result = _load_and_query_nearest(duckdb_connection, a, b) + + # Assert + assert len(giql_result) == len(bedtools_result) == 1 + # Distance: 500 - 200 = 300 (half-open), bedtools may report 301 + assert bedtools_result[0][-1] in (300, 301) + assert giql_result[0][9] == "b1" + + +def test_nearest_should_match_bedtools_for_multiple_query_intervals(duckdb_connection): + """Test NEAREST matches bedtools when multiple query intervals are used. + + Given: + Multiple query intervals in A and multiple candidates in B + When: + GIQL NEAREST is compared to bedtools closest + Then: + It should produce the correct pairing for each query interval + """ + # Arrange + a = [ + GenomicInterval("chr1", 100, 200, "a1", 0, "+"), + GenomicInterval("chr1", 500, 600, "a2", 0, "+"), + GenomicInterval("chr1", 900, 1000, "a3", 0, "+"), + ] + b = [ + GenomicInterval("chr1", 250, 300, "b1", 0, "+"), + GenomicInterval("chr1", 700, 800, "b2", 0, "+"), + ] + + # Act + giql_result, bedtools_result = _load_and_query_nearest(duckdb_connection, a, b) + + # Assert + assert len(giql_result) == len(bedtools_result) == 3 + + giql_sorted = sorted(giql_result, key=lambda r: (r[0], r[1])) + bt_sorted = sorted(bedtools_result, key=lambda r: (r[0], r[1])) + + for giql_row, bt_row in zip(giql_sorted, bt_sorted): + assert giql_row[3] == bt_row[3] # a.name matches + assert giql_row[9] == bt_row[9] # b.name matches + + +def test_nearest_should_return_three_neighbors_when_k_is_three(duckdb_connection): + """Test NEAREST returns the three nearest neighbors when k=3. + + Given: + One query interval and four database candidates + When: + GIQL NEAREST(k=3) is compared to bedtools closest -k 3 + Then: + It should return the same three nearest intervals as bedtools + """ + # Arrange + a = [GenomicInterval("chr1", 400, 500, "a1", 0, "+")] + b = [ + GenomicInterval("chr1", 100, 150, "b_far", 0, "+"), + GenomicInterval("chr1", 350, 390, "b_near", 0, "+"), + GenomicInterval("chr1", 550, 600, "b_close", 0, "+"), + GenomicInterval("chr1", 900, 1000, "b_farther", 0, "+"), + ] + + # Act + giql_result, bedtools_result = _load_and_query_nearest(duckdb_connection, a, b, k=3) + + # Assert + assert len(giql_result) == 3 + assert len(bedtools_result) == 3 + + giql_names = {r[9] for r in giql_result} + bt_names = {r[9] for r in bedtools_result} + assert giql_names == bt_names + + +def test_nearest_should_return_available_neighbors_when_k_exceeds_candidates(duckdb_connection): + """Test NEAREST caps results at the number of available candidates. + + Given: + One query interval, only two database candidates, and k=5 + When: + GIQL NEAREST is compared to bedtools closest + Then: + It should return only the two available candidates, matching bedtools + """ + # Arrange + a = [GenomicInterval("chr1", 200, 300, "a1", 0, "+")] + b = [ + GenomicInterval("chr1", 100, 150, "b1", 0, "+"), + GenomicInterval("chr1", 400, 500, "b2", 0, "+"), + ] + + # Act + giql_result, bedtools_result = _load_and_query_nearest(duckdb_connection, a, b, k=5) + + # Assert + assert len(giql_result) == len(bedtools_result) == 2 + + +def test_nearest_should_return_only_same_strand_candidates_when_stranded(duckdb_connection): + """Test NEAREST restricts matches to same strand when stranded=true. + + Given: + Candidates on same and opposite strands, with the opposite-strand + candidate being closer + When: + GIQL NEAREST(stranded=true) is compared to bedtools closest -s + Then: + It should return only the same-strand match, ignoring the closer + opposite-strand candidate + """ + # Arrange + a = [GenomicInterval("chr1", 100, 200, "a1", 0, "+")] + b = [ + GenomicInterval("chr1", 220, 240, "b_opp", 0, "-"), # closer, opposite + GenomicInterval("chr1", 300, 400, "b_same", 0, "+"), # farther, same + ] + + # Act + giql_result, bedtools_result = _load_and_query_nearest( + duckdb_connection, + a, + b, + stranded=True, + ) + + # Assert + assert len(giql_result) == len(bedtools_result) == 1 + assert giql_result[0][9] == "b_same" + assert bedtools_result[0][9] == "b_same" + + +def test_nearest_should_ignore_strand_when_unstranded(duckdb_connection): + """Test NEAREST ignores strand when not configured as stranded. + + Given: + Candidates on different strands where the closer one is on the + opposite strand + When: + GIQL NEAREST (default) is compared to bedtools closest (default) + Then: + It should return the nearest candidate regardless of strand + """ + # Arrange + a = [GenomicInterval("chr1", 100, 200, "a1", 0, "+")] + b = [ + GenomicInterval("chr1", 250, 300, "b_far", 0, "+"), + GenomicInterval("chr1", 210, 230, "b_near", 0, "-"), + ] + + # Act + giql_result, bedtools_result = _load_and_query_nearest(duckdb_connection, a, b) + + # Assert + assert len(giql_result) == len(bedtools_result) == 1 + assert giql_result[0][9] == "b_near" + assert bedtools_result[0][9] == "b_near" + + +def test_nearest_should_isolate_matches_per_chromosome(duckdb_connection): + """Test NEAREST only pairs intervals on the same chromosome. + + Given: + Intervals distributed across multiple chromosomes + When: + GIQL NEAREST is compared to bedtools closest + Then: + It should find nearest matches only within each chromosome + """ + # Arrange + a = [ + GenomicInterval("chr1", 100, 200, "a1", 0, "+"), + GenomicInterval("chr2", 100, 200, "a2", 0, "+"), + ] + b = [ + GenomicInterval("chr1", 500, 600, "b1", 0, "+"), + GenomicInterval("chr2", 300, 400, "b2", 0, "+"), + ] + + # Act + giql_result, bedtools_result = _load_and_query_nearest(duckdb_connection, a, b) + + # Assert + assert len(giql_result) == len(bedtools_result) == 2 + + for giql_row in giql_result: + assert giql_row[0] == giql_row[6], "A and B should be on same chromosome" + + +def test_nearest_should_match_bedtools_on_large_multi_chromosome_dataset(duckdb_connection): + """Test NEAREST matches bedtools on a large multi-chromosome dataset. + + Given: + Fifty-plus intervals per table spread across three chromosomes + When: + GIQL NEAREST is compared to bedtools closest + Then: + It should produce the same row count as bedtools on the full dataset + """ + # Arrange + intervals_a = generate_random_intervals( + seed=42, + prefix="a", + count_per_chrom=50, + n_chroms=3, + start_max=900_000, + ) + intervals_b = generate_random_intervals( + seed=43, + prefix="b", + count_per_chrom=50, + n_chroms=3, + start_max=900_000, + ) + + # Act + giql_result, bedtools_result = _load_and_query_nearest( + duckdb_connection, + intervals_a, + intervals_b, + ) + + # Assert + assert len(giql_result) == len(bedtools_result), ( + f"Row count mismatch: GIQL={len(giql_result)}, bedtools={len(bedtools_result)}" + ) diff --git a/tests/integration/bedtools/test_correctness_workflows.py b/tests/integration/bedtools/test_correctness_workflows.py new file mode 100644 index 0000000..fc278ec --- /dev/null +++ b/tests/integration/bedtools/test_correctness_workflows.py @@ -0,0 +1,398 @@ +"""Integration correctness tests for multi-operation GIQL workflows. + +These tests validate that chained GIQL operations produce results matching +equivalent bedtools command pipelines. Corresponds to User Story 4 (P3) +from the bedtools integration test spec. +""" + +import pytest + +from giql import transpile + +from .utils.bedtools_wrapper import closest +from .utils.bedtools_wrapper import intersect +from .utils.bedtools_wrapper import merge +from .utils.comparison import compare_results +from .utils.data_models import GenomicInterval +from .utils.duckdb_loader import load_intervals + +pytestmark = pytest.mark.integration + + +def test_pipeline_should_match_bedtools_when_intersect_chained_into_merge(duckdb_connection): + """Test that chaining intersect into merge in GIQL matches the bedtools pipeline. + + Given: + Two interval sets with overlaps on chr1 + When: + GIQL intersects via CTE and then merges, compared against + bedtools intersect piped into bedtools merge + Then: + It should produce identical merged intervals + """ + # Arrange + intervals_a = [ + GenomicInterval("chr1", 100, 200, "a1", 0, "+"), + GenomicInterval("chr1", 150, 300, "a2", 0, "+"), + GenomicInterval("chr1", 500, 600, "a3", 0, "+"), + ] + intervals_b = [ + GenomicInterval("chr1", 180, 250, "b1", 0, "+"), + GenomicInterval("chr1", 520, 580, "b2", 0, "+"), + ] + + load_intervals(duckdb_connection, "intervals_a", [i.to_tuple() for i in intervals_a]) + load_intervals(duckdb_connection, "intervals_b", [i.to_tuple() for i in intervals_b]) + + # Act + # bedtools pipeline: intersect then merge + intersect_result = intersect( + [i.to_tuple() for i in intervals_a], + [i.to_tuple() for i in intervals_b], + ) + bedtools_final = merge(intersect_result) + + # GIQL: use CTE to intersect, then merge + sql = transpile( + """ + WITH hits AS ( + SELECT DISTINCT a.* + FROM intervals_a a, intervals_b b + WHERE a.interval INTERSECTS b.interval + ) + SELECT MERGE(interval) + FROM hits + """, + tables=["intervals_a", "intervals_b", "hits"], + ) + giql_result = duckdb_connection.execute(sql).fetchall() + + # Assert + comparison = compare_results(giql_result, bedtools_final) + assert comparison.match, comparison.failure_message() + + +def test_pipeline_should_filter_by_distance_when_nearest_max_distance_applied(duckdb_connection): + """Test that NEAREST with max_distance matches bedtools closest filtered by distance. + + Given: + Two interval sets where one B interval is within 50bp of an A + interval and another is far beyond that threshold + When: + GIQL runs NEAREST with max_distance := 50, compared against + bedtools closest -d post-filtered to distance <= 50 + Then: + It should return only the close neighbor pair and drop the far one + """ + # Arrange + intervals_a = [ + GenomicInterval("chr1", 100, 200, "a1", 0, "+"), + GenomicInterval("chr1", 500, 600, "a2", 0, "+"), + ] + intervals_b = [ + GenomicInterval("chr1", 220, 250, "b_near", 0, "+"), # 20bp from a1 + GenomicInterval("chr1", 900, 1000, "b_far", 0, "+"), # 300bp from a2 + ] + + load_intervals(duckdb_connection, "intervals_a", [i.to_tuple() for i in intervals_a]) + load_intervals(duckdb_connection, "intervals_b", [i.to_tuple() for i in intervals_b]) + + # Act + # bedtools: closest -d, then filter distance <= 50 + bt_result = closest( + [i.to_tuple() for i in intervals_a], + [i.to_tuple() for i in intervals_b], + ) + bedtools_filtered = [row for row in bt_result if row[-1] <= 50] + + # GIQL: NEAREST with max_distance + sql = transpile( + """ + SELECT a.name, b.name + FROM intervals_a a + CROSS JOIN LATERAL NEAREST( + intervals_b, + reference := a.interval, + k := 1, + max_distance := 50 + ) b + """, + tables=["intervals_a", "intervals_b"], + ) + giql_result = duckdb_connection.execute(sql).fetchall() + + # Assert + # Both should return only a1->b_near (distance 20 <= 50) + # a2->b_far (distance 300 > 50) should be excluded + assert len(giql_result) == len(bedtools_filtered) + if len(giql_result) > 0: + giql_names = {r[0] for r in giql_result} + assert "a1" in giql_names + + +def test_pipeline_should_match_bedtools_when_merge_chained_into_intersect(duckdb_connection): + """Test that merging then intersecting in GIQL matches the bedtools pipeline. + + Given: + An A interval set with overlapping intervals and a disjoint B set + When: + GIQL merges A via CTE then intersects with B, compared against + bedtools merge of A piped into bedtools intersect against B + Then: + It should produce matching interval coordinates + """ + # Arrange + intervals_a = [ + GenomicInterval("chr1", 100, 200, "a1", 0, "+"), + GenomicInterval("chr1", 180, 300, "a2", 0, "+"), + GenomicInterval("chr1", 500, 600, "a3", 0, "+"), + ] + intervals_b = [ + GenomicInterval("chr1", 250, 350, "b1", 0, "+"), + GenomicInterval("chr1", 550, 650, "b2", 0, "+"), + ] + + load_intervals(duckdb_connection, "intervals_a", [i.to_tuple() for i in intervals_a]) + load_intervals(duckdb_connection, "intervals_b", [i.to_tuple() for i in intervals_b]) + + # Act + # bedtools pipeline: merge a, then intersect with b + merged_a = merge([i.to_tuple() for i in intervals_a]) + bedtools_final = intersect(merged_a, [i.to_tuple() for i in intervals_b]) + + # GIQL: CTE to merge, then intersect + sql = transpile( + """ + WITH merged AS ( + SELECT MERGE(interval) AS interval + FROM intervals_a + ) + SELECT DISTINCT m.* + FROM merged m, intervals_b b + WHERE m.interval INTERSECTS b.interval + """, + tables=["intervals_a", "intervals_b", "merged"], + ) + giql_result = duckdb_connection.execute(sql).fetchall() + + # Assert + # MERGE outputs BED3 (chrom, start, end); compare only coordinates + bedtools_coords = [row[:3] for row in bedtools_final] + comparison = compare_results(giql_result, bedtools_coords) + assert comparison.match, comparison.failure_message() + + +def test_pipeline_should_preserve_strand_when_intersect_then_merge_stranded(duckdb_connection): + """Test that strand-aware intersect chained into merge matches the bedtools pipeline. + + Given: + Two interval sets carrying strand information, with mixed plus + and minus strand overlaps + When: + GIQL performs a same-strand intersect via CTE then merges, + compared against bedtools intersect -s piped into bedtools merge + Then: + It should produce matching merged intervals honoring strand + """ + # Arrange + intervals_a = [ + GenomicInterval("chr1", 100, 200, "a1", 0, "+"), + GenomicInterval("chr1", 150, 300, "a2", 0, "+"), + GenomicInterval("chr1", 120, 250, "a3", 0, "-"), + ] + intervals_b = [ + GenomicInterval("chr1", 180, 250, "b1", 0, "+"), + GenomicInterval("chr1", 130, 220, "b2", 0, "-"), + ] + + load_intervals(duckdb_connection, "intervals_a", [i.to_tuple() for i in intervals_a]) + load_intervals(duckdb_connection, "intervals_b", [i.to_tuple() for i in intervals_b]) + + # Act + # bedtools pipeline: intersect -s then merge + intersect_result = intersect( + [i.to_tuple() for i in intervals_a], + [i.to_tuple() for i in intervals_b], + strand_mode="same", + ) + bedtools_final = merge(intersect_result) + + # GIQL: same-strand intersect via CTE then merge + sql = transpile( + """ + WITH hits AS ( + SELECT DISTINCT a.* + FROM intervals_a a, intervals_b b + WHERE a.interval INTERSECTS b.interval + AND a.strand = b.strand + ) + SELECT MERGE(interval) + FROM hits + """, + tables=["intervals_a", "intervals_b", "hits"], + ) + giql_result = duckdb_connection.execute(sql).fetchall() + + # Assert + comparison = compare_results(giql_result, bedtools_final) + assert comparison.match, comparison.failure_message() + + +def test_pipeline_should_match_bedtools_when_intersect_chrom_filter_then_merge(duckdb_connection): + """Test that intersect followed by a chr1 filter and a merge matches the bedtools pipeline. + + Given: + Two interval sets spanning chr1 and chr2 with overlaps on both + chromosomes + When: + GIQL intersects with a chrom = 'chr1' predicate inside a CTE + and then merges, compared against bedtools intersect, then a + Python-side chr1 filter, then bedtools merge + Then: + It should produce matching merged intervals restricted to chr1 + """ + # Arrange + intervals_a = [ + GenomicInterval("chr1", 100, 200, "a1", 0, "+"), + GenomicInterval("chr1", 150, 300, "a2", 0, "+"), + GenomicInterval("chr2", 100, 200, "a3", 0, "+"), + ] + intervals_b = [ + GenomicInterval("chr1", 180, 250, "b1", 0, "+"), + GenomicInterval("chr2", 150, 250, "b2", 0, "+"), + ] + + load_intervals(duckdb_connection, "intervals_a", [i.to_tuple() for i in intervals_a]) + load_intervals(duckdb_connection, "intervals_b", [i.to_tuple() for i in intervals_b]) + + # Act + # bedtools pipeline: intersect, filter chr1, merge + intersect_result = intersect( + [i.to_tuple() for i in intervals_a], + [i.to_tuple() for i in intervals_b], + ) + chr1_only = [r for r in intersect_result if r[0] == "chr1"] + bedtools_final = merge(chr1_only) if chr1_only else [] + + # GIQL: CTE intersect with chr1 filter, then merge + sql = transpile( + """ + WITH chr1_hits AS ( + SELECT DISTINCT a.* + FROM intervals_a a, intervals_b b + WHERE a.interval INTERSECTS b.interval + AND a.chrom = 'chr1' + ) + SELECT MERGE(interval) + FROM chr1_hits + """, + tables=["intervals_a", "intervals_b", "chr1_hits"], + ) + giql_result = duckdb_connection.execute(sql).fetchall() + + # Assert + comparison = compare_results(giql_result, bedtools_final) + assert comparison.match, comparison.failure_message() + + +def test_pipeline_should_match_bedtools_when_chained_step_by_step( + duckdb_connection, +): + """Test chained GIQL pipeline matches bedtools at each step. + + Given: + Three interval sets across two chromosomes — A and B as inputs + for intersect + merge, C as reference for nearest — hand-crafted + so the pipeline output is unambiguous (no tie-breaking) + When: + Each GIQL step's output is materialized as a table and fed + back into the next GIQL step, and each bedtools equivalent + operates on its own prior step's output + Then: + GIQL and bedtools outputs should match at each of the three + stages: full row equality for intersect and merge, and + equal (a_name, b_name) neighbor pairs for nearest (distance + values are compared in the dedicated nearest tests because + bedtools 2.31+ uses the N+1 half-open gap convention) + """ + # Arrange + intervals_a = [ + GenomicInterval("chr1", 100, 300, "a1", 0, "+"), + GenomicInterval("chr1", 500, 700, "a2", 0, "+"), + GenomicInterval("chr2", 100, 300, "a3", 0, "+"), + ] + intervals_b = [ + GenomicInterval("chr1", 200, 400, "b1", 0, "+"), + GenomicInterval("chr1", 600, 800, "b2", 0, "+"), + GenomicInterval("chr2", 200, 400, "b3", 0, "+"), + ] + intervals_c = [ + GenomicInterval("chr1", 5000, 5100, "c1", 0, "+"), + GenomicInterval("chr2", 5000, 5100, "c2", 0, "+"), + ] + load_intervals( + duckdb_connection, "intervals_a", [i.to_tuple() for i in intervals_a] + ) + load_intervals( + duckdb_connection, "intervals_b", [i.to_tuple() for i in intervals_b] + ) + load_intervals( + duckdb_connection, "intervals_c", [i.to_tuple() for i in intervals_c] + ) + + # Act & Assert — Step 1: GIQL intersect vs bedtools intersect + bt_step1 = intersect( + [i.to_tuple() for i in intervals_a], + [i.to_tuple() for i in intervals_b], + ) + sql_step1 = transpile( + """ + SELECT DISTINCT a.* + FROM intervals_a a, intervals_b b + WHERE a.interval INTERSECTS b.interval + """, + tables=["intervals_a", "intervals_b"], + ) + giql_step1 = duckdb_connection.execute(sql_step1).fetchall() + c1 = compare_results(giql_step1, bt_step1) + assert c1.match, f"Step 1 (intersect): {c1.failure_message()}" + + # Act & Assert — Step 2: materialize GIQL step-1 output, GIQL MERGE + assert giql_step1, "fixture should produce at least one intersecting row" + load_intervals(duckdb_connection, "step1_results", giql_step1) + bt_step2 = merge(bt_step1) + sql_step2 = transpile( + "SELECT MERGE(interval) FROM step1_results", + tables=["step1_results"], + ) + giql_step2 = duckdb_connection.execute(sql_step2).fetchall() + c2 = compare_results(giql_step2, bt_step2) + assert c2.match, f"Step 2 (merge): {c2.failure_message()}" + + # Act & Assert — Step 3: pad BED3 step-2 output to BED6, GIQL NEAREST + assert giql_step2, "step 2 should produce at least one merged interval" + giql_step2_bed6 = [ + (row[0], row[1], row[2], f"step2_{i}", 0, "+") + for i, row in enumerate(giql_step2) + ] + load_intervals(duckdb_connection, "step2_results", giql_step2_bed6) + bt_step3 = closest( + giql_step2_bed6, [i.to_tuple() for i in intervals_c] + ) + sql_step3 = transpile( + """ + SELECT a.*, b.* + FROM step2_results a + CROSS JOIN LATERAL NEAREST(intervals_c, reference := a.interval) b + ORDER BY a.chrom, a.start + """, + tables=["step2_results", "intervals_c"], + ) + giql_step3 = duckdb_connection.execute(sql_step3).fetchall() + giql_pairs = {(row[3], row[9]) for row in giql_step3} + bt_pairs = {(row[3], row[9]) for row in bt_step3} + assert giql_pairs == bt_pairs, ( + f"Step 3 (nearest) neighbor pairs differ\n" + f" GIQL: {sorted(giql_pairs)}\n" + f" bedtools: {sorted(bt_pairs)}" + ) diff --git a/tests/integration/bedtools/test_distance.py b/tests/integration/bedtools/test_distance.py index 4fc53e7..628d14d 100644 --- a/tests/integration/bedtools/test_distance.py +++ b/tests/integration/bedtools/test_distance.py @@ -5,8 +5,12 @@ closest -d output. """ +import pytest + from .utils.data_models import GenomicInterval +pytestmark = pytest.mark.integration + def test_distance_non_overlapping(giql_query): """ diff --git a/tests/integration/bedtools/test_intersect.py b/tests/integration/bedtools/test_intersect.py index f4bfd43..c7434b1 100644 --- a/tests/integration/bedtools/test_intersect.py +++ b/tests/integration/bedtools/test_intersect.py @@ -4,6 +4,8 @@ results to bedtools intersect command. """ +import pytest + from giql import transpile from .utils.bedtools_wrapper import intersect @@ -11,6 +13,8 @@ from .utils.data_models import GenomicInterval from .utils.duckdb_loader import load_intervals +pytestmark = pytest.mark.integration + def test_intersect_basic_overlap(duckdb_connection): """ diff --git a/tests/integration/bedtools/test_intersect_property.py b/tests/integration/bedtools/test_intersect_property.py index a977547..1685e4d 100644 --- a/tests/integration/bedtools/test_intersect_property.py +++ b/tests/integration/bedtools/test_intersect_property.py @@ -6,6 +6,8 @@ bedtools intersect. """ +import pytest + from hypothesis import HealthCheck from hypothesis import given from hypothesis import settings @@ -18,6 +20,8 @@ from .utils.data_models import GenomicInterval from .utils.duckdb_loader import load_intervals +pytestmark = pytest.mark.integration + duckdb = __import__("pytest").importorskip("duckdb") diff --git a/tests/integration/bedtools/test_merge.py b/tests/integration/bedtools/test_merge.py index b9724c6..008d991 100644 --- a/tests/integration/bedtools/test_merge.py +++ b/tests/integration/bedtools/test_merge.py @@ -4,6 +4,8 @@ results to bedtools merge command. """ +import pytest + from giql import transpile from .utils.bedtools_wrapper import merge @@ -11,6 +13,8 @@ from .utils.data_models import GenomicInterval from .utils.duckdb_loader import load_intervals +pytestmark = pytest.mark.integration + def test_merge_adjacent_intervals(duckdb_connection): """ diff --git a/tests/integration/bedtools/test_nearest.py b/tests/integration/bedtools/test_nearest.py index 3a91641..80f11da 100644 --- a/tests/integration/bedtools/test_nearest.py +++ b/tests/integration/bedtools/test_nearest.py @@ -4,12 +4,16 @@ consistent with bedtools closest command. """ +import pytest + from giql import transpile from .utils.bedtools_wrapper import closest from .utils.data_models import GenomicInterval from .utils.duckdb_loader import load_intervals +pytestmark = pytest.mark.integration + def test_nearest_non_overlapping(duckdb_connection): """ diff --git a/tests/integration/bedtools/test_strand_aware.py b/tests/integration/bedtools/test_strand_aware.py index f9c8eb9..80f2cca 100644 --- a/tests/integration/bedtools/test_strand_aware.py +++ b/tests/integration/bedtools/test_strand_aware.py @@ -4,6 +4,8 @@ operations, matching bedtools behavior with -s and -S flags. """ +import pytest + from giql import transpile from .utils.bedtools_wrapper import closest @@ -13,6 +15,8 @@ from .utils.data_models import GenomicInterval from .utils.duckdb_loader import load_intervals +pytestmark = pytest.mark.integration + def test_intersect_same_strand(duckdb_connection): """ diff --git a/tests/integration/bedtools/test_within.py b/tests/integration/bedtools/test_within.py index f2935b5..fcb6037 100644 --- a/tests/integration/bedtools/test_within.py +++ b/tests/integration/bedtools/test_within.py @@ -5,8 +5,12 @@ equivalent exists, so tests validate against known expected results. """ +import pytest + from .utils.data_models import GenomicInterval +pytestmark = pytest.mark.integration + def test_within_basic(giql_query): """ diff --git a/tests/integration/bedtools/utils/duckdb_loader.py b/tests/integration/bedtools/utils/duckdb_loader.py index 286b543..6d443d8 100644 --- a/tests/integration/bedtools/utils/duckdb_loader.py +++ b/tests/integration/bedtools/utils/duckdb_loader.py @@ -24,4 +24,7 @@ def load_intervals(conn, table_name: str, intervals: list[tuple]) -> None: strand VARCHAR ) """) - conn.executemany(f"INSERT INTO {table_name} VALUES (?,?,?,?,?,?)", intervals) + if intervals: + conn.executemany( + f"INSERT INTO {table_name} VALUES (?,?,?,?,?,?)", intervals + ) diff --git a/tests/integration/bedtools/utils/random_intervals.py b/tests/integration/bedtools/utils/random_intervals.py new file mode 100644 index 0000000..d2d57cb --- /dev/null +++ b/tests/integration/bedtools/utils/random_intervals.py @@ -0,0 +1,42 @@ +"""Deterministic random-interval generator for bedtools integration tests.""" + +import random + +from .data_models import GenomicInterval + + +def generate_random_intervals( + *, + seed: int, + prefix: str, + count_per_chrom: int = 30, + n_chroms: int = 3, + start_max: int = 100_000, + min_size: int = 100, + max_size: int = 1000, + strand: str = "+", +) -> list[GenomicInterval]: + """Generate a deterministic list of GenomicInterval samples. + + Used by scale tests to produce realistic multi-chromosome input + sets without duplicating the same random-loop boilerplate. The + seed determines the exact sample — callers expecting identical + bedtools and GIQL outputs must pass the same seed to both sides. + """ + rng = random.Random(seed) + intervals: list[GenomicInterval] = [] + for chrom_num in range(1, n_chroms + 1): + for i in range(count_per_chrom): + start = rng.randint(0, start_max) + size = rng.randint(min_size, max_size) + intervals.append( + GenomicInterval( + f"chr{chrom_num}", + start, + start + size, + f"{prefix}_{chrom_num}_{i}", + 0, + strand, + ) + ) + return intervals diff --git a/tests/integration/bedtools/utils/test_bedtools_wrapper.py b/tests/integration/bedtools/utils/test_bedtools_wrapper.py new file mode 100644 index 0000000..72a83a3 --- /dev/null +++ b/tests/integration/bedtools/utils/test_bedtools_wrapper.py @@ -0,0 +1,626 @@ +"""Unit tests for pybedtools wrapper functions.""" + +import shutil + +import pytest + +pybedtools = pytest.importorskip("pybedtools") + +if not shutil.which("bedtools"): + pytest.skip( + "bedtools binary not found in PATH", + allow_module_level=True, + ) + +from .bedtools_wrapper import BedtoolsError # noqa: E402 +from .bedtools_wrapper import bedtool_to_tuples # noqa: E402 +from .bedtools_wrapper import closest # noqa: E402 +from .bedtools_wrapper import create_bedtool # noqa: E402 +from .bedtools_wrapper import intersect # noqa: E402 +from .bedtools_wrapper import merge # noqa: E402 + +pytestmark = pytest.mark.integration + + +def test_create_bedtool_should_parse_bed3(): + """Test that create_bedtool constructs a BedTool from BED3 tuples. + + Given: + A list of BED3 tuples + When: + create_bedtool() is called + Then: + It should return a BedTool with correct intervals + """ + # Arrange / Act + bt = create_bedtool([("chr1", 100, 200)]) + intervals = list(bt) + + # Assert + assert len(intervals) == 1 + assert intervals[0].chrom == "chr1" + assert intervals[0].start == 100 + assert intervals[0].end == 200 + + +def test_create_bedtool_should_parse_bed6(): + """Test that create_bedtool constructs a BedTool from BED6 tuples. + + Given: + A list of BED6 tuples + When: + create_bedtool() is called + Then: + It should return a BedTool with all 6 fields + """ + # Arrange / Act + bt = create_bedtool([("chr1", 100, 200, "a1", 50, "+")]) + intervals = list(bt) + + # Assert + assert len(intervals) == 1 + assert intervals[0].fields == ["chr1", "100", "200", "a1", "50", "+"] + + +def test_create_bedtool_should_replace_none_with_defaults(): + """Test that create_bedtool substitutes defaults for None values. + + Given: + BED6 tuples with None values + When: + create_bedtool() is called + Then: + It should replace None values with defaults + """ + # Arrange / Act + bt = create_bedtool([("chr1", 100, 200, None, None, None)]) + fields = list(bt)[0].fields + + # Assert + assert fields[3] == "." # name + assert fields[4] == "0" # score + assert fields[5] == "." # strand + + +def test_create_bedtool_should_raise_when_tuple_length_invalid(): + """Test that create_bedtool rejects tuples with wrong arity. + + Given: + A tuple with invalid length + When: + create_bedtool() is called + Then: + It should raise ValueError + """ + # Arrange / Act / Assert + with pytest.raises(ValueError, match="Invalid interval format"): + create_bedtool([("chr1", 100)]) + + +def test_create_bedtool_should_accept_multiple_intervals(): + """Test that create_bedtool handles multiple intervals across chromosomes. + + Given: + Multiple intervals across chromosomes + When: + create_bedtool() is called + Then: + It should return a BedTool containing all intervals + """ + # Arrange / Act + bt = create_bedtool( + [ + ("chr1", 100, 200, "a", 0, "+"), + ("chr2", 300, 400, "b", 0, "-"), + ] + ) + intervals = list(bt) + + # Assert + assert len(intervals) == 2 + + +def test_intersect_should_return_overlapping_intervals(): + """Test that intersect returns A intervals overlapping B. + + Given: + Two sets of overlapping intervals + When: + intersect() is called + Then: + It should return intervals from A that overlap B + """ + # Arrange + a = [("chr1", 100, 200, "a1", 100, "+")] + b = [("chr1", 150, 250, "b1", 100, "+")] + + # Act + result = intersect(a, b) + + # Assert + assert len(result) == 1 + assert result[0][0] == "chr1" + + +def test_intersect_should_return_empty_when_no_overlap(): + """Test that intersect returns no rows when intervals disjoint. + + Given: + Non-overlapping intervals + When: + intersect() is called + Then: + It should return an empty list + """ + # Arrange + a = [("chr1", 100, 200, "a1", 100, "+")] + b = [("chr1", 300, 400, "b1", 100, "+")] + + # Act + result = intersect(a, b) + + # Assert + assert result == [] + + +def test_intersect_should_filter_same_strand_only_when_strand_mode_same(): + """Test that intersect in same-strand mode keeps only same-strand hits. + + Given: + Intervals on same and opposite strands + When: + intersect() is called with strand_mode="same" + Then: + It should return only same-strand overlaps + """ + # Arrange + a = [ + ("chr1", 100, 200, "a1", 0, "+"), + ("chr1", 100, 200, "a2", 0, "-"), + ] + b = [("chr1", 150, 250, "b1", 0, "+")] + + # Act + result = intersect(a, b, strand_mode="same") + names = [r[3] for r in result] + + # Assert + assert "a1" in names + assert "a2" not in names + + +def test_intersect_should_filter_opposite_strand_only_when_strand_mode_opposite(): + """Test that intersect in opposite-strand mode keeps only opposite-strand hits. + + Given: + Intervals on same and opposite strands + When: + intersect() is called with strand_mode="opposite" + Then: + It should return only opposite-strand overlaps + """ + # Arrange + a = [ + ("chr1", 100, 200, "a1", 0, "+"), + ("chr1", 100, 200, "a2", 0, "-"), + ] + b = [("chr1", 150, 250, "b1", 0, "+")] + + # Act + result = intersect(a, b, strand_mode="opposite") + names = [r[3] for r in result] + + # Assert + assert "a2" in names + assert "a1" not in names + + +def test_intersect_should_ignore_strand_when_strand_mode_none(): + """Test that intersect ignores strand when strand_mode is None. + + Given: + Overlapping intervals on different strands + When: + intersect() is called with strand_mode=None + Then: + It should return all overlaps regardless of strand + """ + # Arrange + a = [("chr1", 100, 200, "a1", 0, "+")] + b = [("chr1", 150, 250, "b1", 0, "-")] + + # Act + result = intersect(a, b) + + # Assert + assert len(result) == 1 + + +def test_merge_should_combine_overlapping_intervals(): + """Test that merge collapses overlapping intervals into one. + + Given: + Overlapping intervals + When: + merge() is called + Then: + It should return merged BED3 intervals + """ + # Arrange + intervals = [ + ("chr1", 100, 200, "i1", 0, "+"), + ("chr1", 150, 250, "i2", 0, "+"), + ] + + # Act + result = merge(intervals) + + # Assert + assert len(result) == 1 + assert result[0] == ("chr1", 100, 250) + + +def test_merge_should_preserve_separated_intervals(): + """Test that merge keeps non-overlapping intervals separate. + + Given: + Separated intervals + When: + merge() is called + Then: + It should return each interval separately as BED3 + """ + # Arrange + intervals = [ + ("chr1", 100, 200, "i1", 0, "+"), + ("chr1", 300, 400, "i2", 0, "+"), + ] + + # Act + result = merge(intervals) + + # Assert + assert len(result) == 2 + + +def test_merge_should_merge_per_strand_when_strand_mode_same(): + """Test that merge segregates intervals by strand in same-strand mode. + + Given: + Overlapping intervals on different strands + When: + merge() is called with strand_mode="same" + Then: + It should merge per-strand separately + """ + # Arrange + intervals = [ + ("chr1", 100, 200, "i1", 0, "+"), + ("chr1", 150, 250, "i2", 0, "+"), + ("chr1", 120, 220, "i3", 0, "-"), + ] + + # Act + result = merge(intervals, strand_mode="same") + + # Assert + # Should have 2: one merged + strand, one - strand + assert len(result) == 2 + + +def test_merge_should_combine_adjacent_intervals(): + """Test that merge joins intervals where one ends at the next's start. + + Given: + Adjacent intervals (end == start of next) + When: + merge() is called + Then: + It should merge adjacent intervals + """ + # Arrange + intervals = [ + ("chr1", 100, 200, "i1", 0, "+"), + ("chr1", 200, 300, "i2", 0, "+"), + ] + + # Act + result = merge(intervals) + + # Assert + assert len(result) == 1 + assert result[0] == ("chr1", 100, 300) + + +def test_closest_should_pair_a_with_nearest_b_and_distance(): + """Test that closest pairs each A interval with the nearest B and a distance. + + Given: + Non-overlapping intervals + When: + closest() is called + Then: + It should return each A paired with nearest B plus distance + """ + # Arrange + a = [("chr1", 100, 200, "a1", 100, "+")] + b = [("chr1", 300, 400, "b1", 100, "+")] + + # Act + result = closest(a, b) + + # Assert + assert len(result) == 1 + # bedtools >= 2.31 reports N+1 for an N-base half-open gap between + # intervals (here 300 - 200 = 100, so expected distance is 101). + # The project pins bedtools >= 2.31.0 via pixi. + assert result[0][-1] == 101 + + +def test_closest_should_match_per_chromosome(): + """Test that closest restricts neighbor search to the same chromosome. + + Given: + Intervals on different chromosomes + When: + closest() is called + Then: + It should find the nearest per-chromosome + """ + # Arrange + a = [ + ("chr1", 100, 200, "a1", 0, "+"), + ("chr2", 100, 200, "a2", 0, "+"), + ] + b = [ + ("chr1", 300, 400, "b1", 0, "+"), + ("chr2", 500, 600, "b2", 0, "+"), + ] + + # Act + result = closest(a, b) + + # Assert + assert len(result) == 2 + # Each A should match B on same chromosome + for row in result: + assert row[0] == row[6] # a.chrom == b.chrom + + +def test_closest_should_return_nearest_same_strand_when_strand_mode_same(): + """Test that closest in same-strand mode picks the nearest same-strand B. + + Given: + Intervals with mixed strands + When: + closest() is called with strand_mode="same" + Then: + It should return the nearest same-strand interval + """ + # Arrange + a = [("chr1", 100, 200, "a1", 0, "+")] + b = [ + ("chr1", 220, 240, "b_opp", 0, "-"), # closer but opposite + ("chr1", 300, 400, "b_same", 0, "+"), # farther but same + ] + + # Act + result = closest(a, b, strand_mode="same") + + # Assert + assert len(result) == 1 + assert result[0][9] == "b_same" + + +def test_closest_should_return_k_neighbors(): + """Test that closest returns up to k nearest neighbors when k > 1. + + Given: + One query and three database intervals + When: + closest() is called with k=3 + Then: + It should return up to 3 nearest + """ + # Arrange + a = [("chr1", 200, 300, "a1", 0, "+")] + b = [ + ("chr1", 100, 150, "b1", 0, "+"), + ("chr1", 350, 400, "b2", 0, "+"), + ("chr1", 500, 600, "b3", 0, "+"), + ] + + # Act + result = closest(a, b, k=3) + + # Assert + # bedtools 2.31 with -t first collapses tied-distance candidates + # (b1 and b2 are both distance 51 from a1), so k=3 returns 2 rows + # for this specific fixture rather than 3. + assert len(result) == 2 + + +def test_bedtool_to_tuples_should_parse_bed3(): + """Test that bedtool_to_tuples converts BED3 intervals to 3-tuples. + + Given: + A BedTool with BED3 intervals + When: + bedtool_to_tuples() is called with bed_format="bed3" + Then: + It should return a list of (chrom, start, end) tuples with int positions + """ + # Arrange + bt = pybedtools.BedTool("chr1\t100\t200\n", from_string=True) + + # Act + result = bedtool_to_tuples(bt, bed_format="bed3") + + # Assert + assert result == [("chr1", 100, 200)] + + +def test_bedtool_to_tuples_should_parse_bed6(): + """Test that bedtool_to_tuples converts BED6 intervals to 6-tuples. + + Given: + A BedTool with BED6 intervals + When: + bedtool_to_tuples() is called with bed_format="bed6" + Then: + It should return a list of 6-tuples with correct types + """ + # Arrange + bt = pybedtools.BedTool("chr1\t100\t200\tgene1\t500\t+\n", from_string=True) + + # Act + result = bedtool_to_tuples(bt, bed_format="bed6") + + # Assert + assert result == [("chr1", 100, 200, "gene1", 500, "+")] + + +def test_bedtool_to_tuples_should_convert_dot_to_none_for_bed6(): + """Test that bedtool_to_tuples maps "." placeholders to None in BED6. + + Given: + A BedTool with "." for name and strand + When: + bedtool_to_tuples() is called with bed_format="bed6" + Then: + It should convert "." values to None + """ + # Arrange + bt = pybedtools.BedTool("chr1\t100\t200\t.\t0\t.\n", from_string=True) + + # Act + result = bedtool_to_tuples(bt, bed_format="bed6") + + # Assert + assert result[0][3] is None # name + assert result[0][5] is None # strand + + +def test_bedtool_to_tuples_should_pad_missing_bed6_fields(): + """Test that bedtool_to_tuples pads short rows to 6 fields. + + Given: + A BedTool with fewer than 6 fields + When: + bedtool_to_tuples() is called with bed_format="bed6" + Then: + It should pad missing fields with defaults + """ + # Arrange + bt = pybedtools.BedTool("chr1\t100\t200\n", from_string=True) + + # Act + result = bedtool_to_tuples(bt, bed_format="bed6") + + # Assert + assert len(result) == 1 + assert len(result[0]) == 6 + + +def test_bedtool_to_tuples_should_parse_closest_format(): + """Test that bedtool_to_tuples parses the 13-field closest format. + + Given: + A BedTool from closest operation (13 fields) + When: + bedtool_to_tuples() is called with bed_format="closest" + Then: + It should return tuples with A fields, B fields, and distance + """ + # Arrange + line = "chr1\t100\t200\ta1\t50\t+\tchr1\t300\t400\tb1\t75\t+\t100\n" + bt = pybedtools.BedTool(line, from_string=True) + + # Act + result = bedtool_to_tuples(bt, bed_format="closest") + + # Assert + assert len(result) == 1 + row = result[0] + assert row[0] == "chr1" # a.chrom + assert row[1] == 100 # a.start (int) + assert row[6] == "chr1" # b.chrom + assert row[7] == 300 # b.start (int) + assert row[12] == 100 # distance (int) + + +def test_bedtool_to_tuples_should_convert_dot_to_none_for_closest(): + """Test that bedtool_to_tuples maps "." placeholders to None in closest rows. + + Given: + A BedTool from closest with "." scores/names + When: + bedtool_to_tuples() is called with bed_format="closest" + Then: + It should convert "." values to None + """ + # Arrange + line = "chr1\t100\t200\t.\t.\t.\tchr1\t300\t400\t.\t.\t.\t50\n" + bt = pybedtools.BedTool(line, from_string=True) + + # Act + result = bedtool_to_tuples(bt, bed_format="closest") + + # Assert + row = result[0] + assert row[3] is None # a.name + assert row[4] is None # a.score + assert row[5] is None # a.strand + assert row[9] is None # b.name + + +def test_bedtool_to_tuples_should_raise_when_format_invalid(): + """Test that bedtool_to_tuples rejects unknown bed_format values. + + Given: + Any BedTool + When: + bedtool_to_tuples() is called with invalid format + Then: + It should raise ValueError + """ + # Arrange + bt = pybedtools.BedTool("chr1\t100\t200\n", from_string=True) + + # Act / Assert + with pytest.raises(ValueError, match="Unsupported format"): + bedtool_to_tuples(bt, bed_format="invalid") + + +def test_bedtool_to_tuples_should_raise_when_closest_fields_insufficient(): + """Test that bedtool_to_tuples rejects closest rows with too few fields. + + Given: + A BedTool with fewer than 13 fields + When: + bedtool_to_tuples() is called with bed_format="closest" + Then: + It should raise ValueError + """ + # Arrange + bt = pybedtools.BedTool("chr1\t100\t200\ta1\t0\t+\n", from_string=True) + + # Act / Assert + with pytest.raises(ValueError, match="Unexpected number of fields"): + bedtool_to_tuples(bt, bed_format="closest") + + +class TestBedtoolsError: + def test___init___should_create_exception_with_message(self): + """Test that BedtoolsError behaves as an Exception carrying its message. + + Given: + A message string + When: + BedtoolsError is raised + Then: + It should be an instance of Exception with the correct message + """ + # Arrange / Act / Assert + with pytest.raises(BedtoolsError, match="test error"): + raise BedtoolsError("test error") diff --git a/tests/integration/bedtools/utils/test_comparison.py b/tests/integration/bedtools/utils/test_comparison.py new file mode 100644 index 0000000..d3fa74f --- /dev/null +++ b/tests/integration/bedtools/utils/test_comparison.py @@ -0,0 +1,389 @@ +"""Unit tests for result comparison logic.""" + +import pytest +from hypothesis import given +from hypothesis import settings +from hypothesis import strategies as st + +from .comparison import compare_results + +pytestmark = pytest.mark.integration + + +def test_compare_results_should_report_match_when_rows_identical(): + """Test that identical row lists compare as matching. + + Given: + Two identical lists of tuples + When: + compare_results() is called + Then: + It should return match=True with no differences + """ + # Arrange + rows = [("chr1", 100, 200), ("chr1", 300, 400)] + + # Act + result = compare_results(rows, rows) + + # Assert + assert result.match is True + assert result.differences == [] + + +def test_compare_results_should_match_when_rows_in_different_order(): + """Test that row order does not affect match outcome. + + Given: + Same tuples in different order + When: + compare_results() is called + Then: + It should return match=True + """ + # Arrange + a = [("chr1", 300, 400), ("chr1", 100, 200)] + b = [("chr1", 100, 200), ("chr1", 300, 400)] + + # Act + result = compare_results(a, b) + + # Assert + assert result.match is True + + +def test_compare_results_should_report_mismatch_when_row_counts_differ(): + """Test that differing row counts produce a mismatch. + + Given: + Lists with different row counts + When: + compare_results() is called + Then: + It should return match=False with a row count difference + """ + # Arrange + a = [("chr1", 100, 200)] + b = [("chr1", 100, 200), ("chr1", 300, 400)] + + # Act + result = compare_results(a, b) + + # Assert + assert result.match is False + assert any("Row count" in d for d in result.differences) + + +def test_compare_results_should_match_when_integer_values_identical(): + """Test that identical integer values compare as matching. + + Given: + Rows with identical integer values + When: + compare_results() is called + Then: + It should return match=True + """ + # Arrange + a = [("chr1", 100, 200, 50)] + b = [("chr1", 100, 200, 50)] + + # Act + result = compare_results(a, b) + + # Assert + assert result.match is True + + +def test_compare_results_should_match_when_floats_within_epsilon(): + """Test that floats within default epsilon compare as matching. + + Given: + Rows with floats differing by less than epsilon + When: + compare_results() is called + Then: + It should return match=True + """ + # Arrange + a = [(1.0000000001,)] + b = [(1.0,)] + + # Act + result = compare_results(a, b) + + # Assert + assert result.match is True + + +def test_compare_results_should_report_mismatch_when_floats_beyond_epsilon(): + """Test that floats beyond default epsilon produce a mismatch. + + Given: + Rows with floats differing by more than epsilon + When: + compare_results() is called + Then: + It should return match=False + """ + # Arrange + a = [(1.5,)] + b = [(1.0,)] + + # Act + result = compare_results(a, b) + + # Assert + assert result.match is False + + +def test_compare_results_should_match_when_custom_epsilon_tolerates_difference(): + """Test that a larger custom epsilon accommodates small float deltas. + + Given: + Rows with floats differing by 0.05 + When: + compare_results() is called with epsilon=0.1 + Then: + It should return match=True + """ + # Arrange + a = [(1.05,)] + b = [(1.0,)] + + # Act + result = compare_results(a, b, epsilon=0.1) + + # Assert + assert result.match is True + + +def test_compare_results_should_match_when_none_values_align(): + """Test that aligned None values compare as matching. + + Given: + Rows with None in the same positions + When: + compare_results() is called + Then: + It should return match=True + """ + # Arrange + a = [("chr1", None, 200)] + b = [("chr1", None, 200)] + + # Act + result = compare_results(a, b) + + # Assert + assert result.match is True + + +def test_compare_results_should_report_mismatch_when_none_vs_value(): + """Test that None paired with a concrete value produces a mismatch. + + Given: + Rows where one has None and the other has a value + When: + compare_results() is called + Then: + It should return match=False + """ + # Arrange + a = [("chr1", None, 200)] + b = [("chr1", 100, 200)] + + # Act + result = compare_results(a, b) + + # Assert + assert result.match is False + + +def test_compare_results_should_report_mismatch_when_column_counts_differ(): + """Test that differing column counts produce a mismatch. + + Given: + Rows with different column counts + When: + compare_results() is called + Then: + It should return match=False with a column count difference + """ + # Arrange + a = [("chr1", 100, 200)] + b = [("chr1", 100)] + + # Act + result = compare_results(a, b) + + # Assert + assert result.match is False + assert any("Column count" in d for d in result.differences) + + +def test_compare_results_should_list_extra_rows_when_giql_has_more(): + """Test that extra GIQL rows are reported in differences. + + Given: + GIQL has extra rows not in bedtools + When: + compare_results() is called + Then: + It should list the extra rows in differences + """ + # Arrange + a = [("chr1", 100, 200), ("chr1", 300, 400)] + b = [("chr1", 100, 200)] + + # Act + result = compare_results(a, b) + + # Assert + assert result.match is False + assert any( + "missing in bedtools" in d.lower() or "Present in GIQL" in d + for d in result.differences + ) + + +def test_compare_results_should_list_missing_rows_when_bedtools_has_more(): + """Test that extra bedtools rows are reported as missing in GIQL. + + Given: + bedtools has extra rows not in GIQL + When: + compare_results() is called + Then: + It should list the missing rows in differences + """ + # Arrange + a = [("chr1", 100, 200)] + b = [("chr1", 100, 200), ("chr1", 300, 400)] + + # Act + result = compare_results(a, b) + + # Assert + assert result.match is False + assert any("Missing in GIQL" in d for d in result.differences) + + +def test_compare_results_should_match_when_both_empty(): + """Test that two empty lists compare as matching with zero counts. + + Given: + Both lists empty + When: + compare_results() is called + Then: + It should return match=True with zero row counts + """ + # Arrange + # (no inputs to arrange beyond the empty lists passed below) + + # Act + result = compare_results([], []) + + # Assert + assert result.match is True + assert result.giql_row_count == 0 + assert result.bedtools_row_count == 0 + + +def test_compare_results_should_populate_metadata_with_epsilon_and_sorted(): + """Test that comparison metadata includes epsilon and sorted keys. + + Given: + Any comparison + When: + compare_results() is called + Then: + It should populate comparison_metadata with epsilon and sorted keys + """ + # Arrange + # (no inputs to arrange beyond the empty lists passed below) + + # Act + result = compare_results([], []) + + # Assert + assert "epsilon" in result.comparison_metadata + assert "sorted" in result.comparison_metadata + + +def test_compare_results_should_set_row_counts_when_sizes_differ(): + """Test that row counts are populated from the input list sizes. + + Given: + Lists of different sizes + When: + compare_results() is called + Then: + It should set giql_row_count and bedtools_row_count correctly + """ + # Arrange + # (inputs are supplied inline in the Act step) + + # Act + result = compare_results( + [("a",), ("b",)], + [("a",), ("b",), ("c",)], + ) + + # Assert + assert result.giql_row_count == 2 + assert result.bedtools_row_count == 3 + + +def test_compare_results_should_match_when_sorting_handles_none_values(): + """Test that sorting with None values completes without errors. + + Given: + Rows containing None values in different positions + When: + compare_results() is called + Then: + It should handle None deterministically and return match=True + """ + # Arrange + a = [("chr1", None, 200), ("chr1", 100, 200)] + b = [("chr1", 100, 200), ("chr1", None, 200)] + + # Act + result = compare_results(a, b) + + # Assert + assert result.match is True + + +@settings(max_examples=50) +@given( + rows=st.lists( + st.tuples( + st.sampled_from(["chr1", "chr2"]), + st.integers(min_value=0, max_value=10000), + st.integers(min_value=0, max_value=10000), + ), + min_size=0, + max_size=20, + ) +) +def test_compare_results_should_always_match_when_comparing_rows_to_themselves(rows): + """Test that self-comparison always yields a match. + + Given: + Any list of tuples + When: + compare_results(rows, rows) is called + Then: + It should always return match=True + """ + # Arrange + # (rows supplied by Hypothesis) + + # Act + result = compare_results(rows, rows) + + # Assert + assert result.match is True diff --git a/tests/integration/bedtools/utils/test_data_models.py b/tests/integration/bedtools/utils/test_data_models.py new file mode 100644 index 0000000..707a92c --- /dev/null +++ b/tests/integration/bedtools/utils/test_data_models.py @@ -0,0 +1,423 @@ +"""Unit tests for bedtools integration test data models.""" + +import pytest +from hypothesis import given +from hypothesis import settings +from hypothesis import strategies as st + +from .data_models import ComparisonResult +from .data_models import GenomicInterval + +pytestmark = pytest.mark.integration + + +class TestGenomicInterval: + def test___init___should_succeed_when_minimal_args_supplied(self): + """Test that minimal instantiation populates required fields and defaults. + + Given: + Valid chrom, start, end values + When: + GenomicInterval is instantiated + Then: + It should create an object with correct attributes and None defaults + """ + # Arrange / Act + gi = GenomicInterval("chr1", 100, 200) + + # Assert + assert gi.chrom == "chr1" + assert gi.start == 100 + assert gi.end == 200 + assert gi.name is None + assert gi.score is None + assert gi.strand is None + + def test___init___should_populate_optional_fields_when_supplied(self): + """Test that all fields are set when provided to the constructor. + + Given: + All fields provided + When: + GenomicInterval is instantiated + Then: + It should set all attributes correctly + """ + # Arrange / Act + gi = GenomicInterval("chrX", 500, 1000, "gene1", 800, "+") + + # Assert + assert gi.chrom == "chrX" + assert gi.start == 500 + assert gi.end == 1000 + assert gi.name == "gene1" + assert gi.score == 800 + assert gi.strand == "+" + + def test___post_init___should_raise_when_start_equals_end(self): + """Test that a zero-length interval is rejected. + + Given: + start equals end + When: + GenomicInterval is instantiated + Then: + It should raise ValueError + """ + # Arrange / Act / Assert + with pytest.raises(ValueError, match="start .* >= end"): + GenomicInterval("chr1", 200, 200) + + def test___post_init___should_raise_when_start_greater_than_end(self): + """Test that an inverted interval is rejected. + + Given: + start > end + When: + GenomicInterval is instantiated + Then: + It should raise ValueError + """ + # Arrange / Act / Assert + with pytest.raises(ValueError, match="start .* >= end"): + GenomicInterval("chr1", 300, 200) + + def test___post_init___should_raise_when_start_is_negative(self): + """Test that a negative start coordinate is rejected. + + Given: + start < 0 + When: + GenomicInterval is instantiated + Then: + It should raise ValueError + """ + # Arrange / Act / Assert + with pytest.raises(ValueError, match="start .* < 0"): + GenomicInterval("chr1", -1, 200) + + def test___post_init___should_raise_when_strand_is_invalid(self): + """Test that an invalid strand value is rejected. + + Given: + An invalid strand value + When: + GenomicInterval is instantiated + Then: + It should raise ValueError + """ + # Arrange / Act / Assert + with pytest.raises(ValueError, match="Invalid strand"): + GenomicInterval("chr1", 100, 200, strand="X") + + def test___post_init___should_raise_when_score_below_range(self): + """Test that a score below the BED range is rejected. + + Given: + score < 0 + When: + GenomicInterval is instantiated + Then: + It should raise ValueError + """ + # Arrange / Act / Assert + with pytest.raises(ValueError, match="Invalid score"): + GenomicInterval("chr1", 100, 200, score=-1) + + def test___post_init___should_raise_when_score_above_range(self): + """Test that a score above the BED range is rejected. + + Given: + score > 1000 + When: + GenomicInterval is instantiated + Then: + It should raise ValueError + """ + # Arrange / Act / Assert + with pytest.raises(ValueError, match="Invalid score"): + GenomicInterval("chr1", 100, 200, score=1001) + + @pytest.mark.parametrize("strand", ["+", "-", "."]) + def test___post_init___should_accept_when_strand_is_valid(self, strand): + """Test that each allowed strand value is accepted. + + Given: + A valid strand value + When: + GenomicInterval is instantiated + Then: + It should create the object successfully + """ + # Arrange / Act + gi = GenomicInterval("chr1", 100, 200, strand=strand) + + # Assert + assert gi.strand == strand + + def test___post_init___should_accept_when_score_is_zero(self): + """Test that the lower boundary score is accepted. + + Given: + score = 0 + When: + GenomicInterval is instantiated + Then: + It should create the object successfully + """ + # Arrange / Act + gi = GenomicInterval("chr1", 100, 200, score=0) + + # Assert + assert gi.score == 0 + + def test___post_init___should_accept_when_score_is_thousand(self): + """Test that the upper boundary score is accepted. + + Given: + score = 1000 + When: + GenomicInterval is instantiated + Then: + It should create the object successfully + """ + # Arrange / Act + gi = GenomicInterval("chr1", 100, 200, score=1000) + + # Assert + assert gi.score == 1000 + + def test_to_tuple_should_return_all_fields_when_fully_populated(self): + """Test that to_tuple returns every field in order. + + Given: + A GenomicInterval with all fields + When: + to_tuple() is called + Then: + It should return a 6-element tuple with all field values + """ + # Arrange + gi = GenomicInterval("chr1", 100, 200, "a1", 500, "+") + + # Act + result = gi.to_tuple() + + # Assert + assert result == ("chr1", 100, 200, "a1", 500, "+") + + def test_to_tuple_should_include_none_when_optional_fields_missing(self): + """Test that to_tuple preserves None for unset optional fields. + + Given: + A GenomicInterval with optional fields as None + When: + to_tuple() is called + Then: + It should return a tuple containing None for optional fields + """ + # Arrange + gi = GenomicInterval("chr1", 100, 200) + + # Act + result = gi.to_tuple() + + # Assert + assert result == ("chr1", 100, 200, None, None, None) + + @settings(max_examples=50) + @given( + chrom=st.sampled_from(["chr1", "chr2", "chrX", "chrM"]), + start=st.integers(min_value=0, max_value=999_999), + size=st.integers(min_value=1, max_value=10_000), + strand=st.sampled_from(["+", "-", "."]), + score=st.integers(min_value=0, max_value=1000), + ) + def test_to_tuple_should_roundtrip_when_any_valid_interval( + self, chrom, start, size, strand, score + ): + """Test that to_tuple reflects the exact constructor inputs. + + Given: + Any valid GenomicInterval + When: + to_tuple() is called + Then: + It should return a tuple that matches the interval's key fields + """ + # Arrange + end = start + size + gi = GenomicInterval(chrom, start, end, "name", score, strand) + + # Act + t = gi.to_tuple() + + # Assert + assert t == (chrom, start, end, "name", score, strand) + + +class TestComparisonResult: + def test___init___should_populate_attributes_when_match_is_true(self): + """Test that a matching result stores its fields correctly. + + Given: + match=True with equal row counts + When: + ComparisonResult is instantiated + Then: + It should set attributes correctly with an empty differences list + """ + # Arrange / Act + cr = ComparisonResult(match=True, giql_row_count=5, bedtools_row_count=5) + + # Assert + assert cr.match is True + assert cr.giql_row_count == 5 + assert cr.bedtools_row_count == 5 + assert cr.differences == [] + + def test___init___should_populate_attributes_when_match_is_false(self): + """Test that a mismatching result stores its differences. + + Given: + match=False with differences + When: + ComparisonResult is instantiated + Then: + It should set attributes correctly including the differences list + """ + # Arrange + diffs = ["Row 0: mismatch"] + + # Act + cr = ComparisonResult( + match=False, + giql_row_count=3, + bedtools_row_count=4, + differences=diffs, + ) + + # Assert + assert cr.match is False + assert cr.differences == diffs + + def test___bool___should_return_true_when_match_is_true(self): + """Test truthiness of a matching result. + + Given: + A matching ComparisonResult + When: + Used in a boolean context + Then: + It should evaluate to True + """ + # Arrange + cr = ComparisonResult(match=True, giql_row_count=1, bedtools_row_count=1) + + # Act / Assert + assert cr + + def test___bool___should_return_false_when_match_is_false(self): + """Test falsiness of a non-matching result. + + Given: + A non-matching ComparisonResult + When: + Used in a boolean context + Then: + It should evaluate to False + """ + # Arrange + cr = ComparisonResult(match=False, giql_row_count=1, bedtools_row_count=2) + + # Act / Assert + assert not cr + + def test_failure_message_should_return_success_when_match_is_true(self): + """Test the message for a matching result. + + Given: + A matching ComparisonResult + When: + failure_message() is called + Then: + It should return a success message + """ + # Arrange + cr = ComparisonResult(match=True, giql_row_count=1, bedtools_row_count=1) + + # Act + msg = cr.failure_message() + + # Assert + assert "match" in msg.lower() + + def test_failure_message_should_include_counts_and_diffs_when_mismatch(self): + """Test the message formatting for a mismatching result. + + Given: + A non-matching ComparisonResult with differences + When: + failure_message() is called + Then: + It should return a formatted message with row counts and differences + """ + # Arrange + cr = ComparisonResult( + match=False, + giql_row_count=3, + bedtools_row_count=5, + differences=["Row 0: val mismatch", "Row 1: missing"], + ) + + # Act + msg = cr.failure_message() + + # Assert + assert "3" in msg + assert "5" in msg + assert "Row 0: val mismatch" in msg + assert "Row 1: missing" in msg + + def test_failure_message_should_truncate_when_over_ten_differences(self): + """Test that the message truncates the differences list at ten. + + Given: + A ComparisonResult with more than 10 differences + When: + failure_message() is called + Then: + It should show only the first 10 with a count of the remainder + """ + # Arrange + diffs = [f"diff_{i}" for i in range(15)] + cr = ComparisonResult( + match=False, + giql_row_count=0, + bedtools_row_count=15, + differences=diffs, + ) + + # Act + msg = cr.failure_message() + + # Assert + assert "diff_9" in msg + assert "diff_10" not in msg + assert "5 more" in msg + + def test___init___should_default_metadata_when_not_supplied(self): + """Test that comparison_metadata defaults to an empty dict. + + Given: + No comparison_metadata provided + When: + ComparisonResult is instantiated + Then: + It should default metadata to an empty dict + """ + # Arrange / Act + cr = ComparisonResult(match=True, giql_row_count=0, bedtools_row_count=0) + + # Assert + assert cr.comparison_metadata == {} diff --git a/tests/integration/bedtools/utils/test_duckdb_loader.py b/tests/integration/bedtools/utils/test_duckdb_loader.py new file mode 100644 index 0000000..d2bc230 --- /dev/null +++ b/tests/integration/bedtools/utils/test_duckdb_loader.py @@ -0,0 +1,120 @@ +"""Unit tests for DuckDB interval loading utility.""" + +import duckdb +import pytest + +from .duckdb_loader import load_intervals + +pytestmark = pytest.mark.integration + + +@pytest.fixture() +def conn(): + c = duckdb.connect(":memory:") + yield c + c.close() + + +def test_load_intervals_should_create_table_with_default_schema(conn): + """Test that load_intervals creates a table with the default GIQL schema. + + Given: + A DuckDB connection and a single interval tuple. + When: + load_intervals is called with a target table name. + Then: + It should create a table with columns chrom, start, end, name, score, strand. + """ + # Arrange, act, & assert + load_intervals(conn, "test_table", [("chr1", 100, 200, "a1", 50, "+")]) + cols = conn.execute( + "SELECT column_name FROM information_schema.columns " + "WHERE table_name = 'test_table' ORDER BY ordinal_position" + ).fetchall() + col_names = [c[0] for c in cols] + assert col_names == ["chrom", "start", "end", "name", "score", "strand"] + + +def test_load_intervals_should_insert_all_tuples(conn): + """Test that load_intervals inserts every provided tuple. + + Given: + A DuckDB connection and multiple interval tuples. + When: + load_intervals is called and the resulting table is queried. + Then: + It should persist each tuple with its exact field values. + """ + # Arrange + intervals = [ + ("chr1", 100, 200, "a1", 50, "+"), + ("chr2", 300, 400, "a2", 75, "-"), + ] + + # Act + load_intervals(conn, "t", intervals) + + # Assert + rows = conn.execute("SELECT * FROM t ORDER BY chrom").fetchall() + assert len(rows) == 2 + assert rows[0] == ("chr1", 100, 200, "a1", 50, "+") + assert rows[1] == ("chr2", 300, 400, "a2", 75, "-") + + +def test_load_intervals_should_store_nulls_when_optional_fields_are_none(conn): + """Test that load_intervals preserves None values for optional fields. + + Given: + A DuckDB connection and an interval tuple with None for name, score, and strand. + When: + load_intervals is called and the row is read back. + Then: + It should store the optional fields as SQL NULL values. + """ + # Arrange, act, & assert + load_intervals(conn, "t", [("chr1", 100, 200, None, None, None)]) + row = conn.execute("SELECT * FROM t").fetchone() + assert row == ("chr1", 100, 200, None, None, None) + + +def test_load_intervals_should_insert_all_rows_when_intervals_span_multiple_chromosomes(conn): + """Test that load_intervals loads intervals across different chromosomes. + + Given: + A DuckDB connection and interval tuples referencing chr1, chr2, and chrX. + When: + load_intervals is called with the cross-chromosome dataset. + Then: + It should insert every row regardless of its chromosome label. + """ + # Arrange + intervals = [ + ("chr1", 100, 200, "a", 0, "+"), + ("chr2", 100, 200, "b", 0, "+"), + ("chrX", 100, 200, "c", 0, "+"), + ] + + # Act + load_intervals(conn, "t", intervals) + + # Assert + count = conn.execute("SELECT COUNT(*) FROM t").fetchone()[0] + assert count == 3 + + +def test_load_intervals_should_create_empty_table_when_intervals_empty(conn): + """Test that load_intervals accepts an empty interval list. + + Given: + A DuckDB connection and an empty list of intervals. + When: + load_intervals is called with the empty list. + Then: + It should create the table with the default schema and zero rows. + """ + # Arrange, act + load_intervals(conn, "t", []) + + # Assert + count = conn.execute("SELECT COUNT(*) FROM t").fetchone()[0] + assert count == 0 diff --git a/tests/unit/__init__.py b/tests/unit/__init__.py new file mode 100644 index 0000000..bc36148 --- /dev/null +++ b/tests/unit/__init__.py @@ -0,0 +1 @@ +"""Unit tests for bedtools integration test utilities.""" diff --git a/tests/unit/test_dialect.py b/tests/unit/test_dialect.py new file mode 100644 index 0000000..b3a9593 --- /dev/null +++ b/tests/unit/test_dialect.py @@ -0,0 +1,394 @@ +"""Tests for giql.dialect module.""" + +from sqlglot import exp +from sqlglot import parse_one +from sqlglot.tokens import TokenType + +from giql.dialect import CONTAINS +from giql.dialect import INTERSECTS +from giql.dialect import WITHIN +from giql.dialect import GIQLDialect +from giql.expressions import Contains +from giql.expressions import GIQLCluster +from giql.expressions import GIQLRasterize +from giql.expressions import GIQLDistance +from giql.expressions import GIQLMerge +from giql.expressions import GIQLNearest +from giql.expressions import Intersects +from giql.expressions import SpatialPredicate +from giql.expressions import SpatialSetPredicate +from giql.expressions import Within + + +class TestDialectConstants: + """Tests for module-level constants and token registration.""" + + def test_constants_should_equal_their_uppercase_names(self): + """Test module-level spatial-operator constants expose their uppercase names. + + Given: + The giql.dialect module is imported + When: + INTERSECTS, CONTAINS, and WITHIN constants are accessed + Then: + It should equal "INTERSECTS", "CONTAINS", and "WITHIN" respectively + """ + # DC-001 + # Arrange / Act / Assert + assert INTERSECTS == "INTERSECTS" + assert CONTAINS == "CONTAINS" + assert WITHIN == "WITHIN" + + def test_TokenType_should_expose_spatial_operator_attributes(self): + """Test that TokenType is extended with spatial-operator attributes. + + Given: + The giql.dialect module is imported + When: + TokenType attributes are checked for spatial operators + Then: + It should expose INTERSECTS, CONTAINS, and WITHIN attributes + """ + # DC-002 + # Arrange / Act / Assert + assert hasattr(TokenType, "INTERSECTS") + assert hasattr(TokenType, "CONTAINS") + assert hasattr(TokenType, "WITHIN") + + +class TestGIQLDialect: + """Tests for GIQLDialect parsing of spatial predicates and GIQL functions.""" + + def test_parse_one_should_produce_Intersects_node_for_intersects_predicate(self): + """Test parsing `column INTERSECTS 'chr1:1000-2000'` yields an Intersects node. + + Given: + A SELECT query containing `column INTERSECTS 'chr1:1000-2000'` + When: + The query is parsed with GIQLDialect + Then: + It should produce an Intersects node whose left side is the column + and whose right side is the literal range string + """ + # GD-001 + # Arrange + query = "SELECT * FROM t WHERE column INTERSECTS 'chr1:1000-2000'" + + # Act + ast = parse_one( + query, + dialect=GIQLDialect, + ) + + # Assert + nodes = list(ast.find_all(Intersects)) + assert len(nodes) == 1 + node = nodes[0] + assert node.this.name == "column" + assert node.expression.this == "chr1:1000-2000" + + def test_parse_one_should_produce_Contains_node_for_contains_predicate(self): + """Test parsing `column CONTAINS 'chr1:1500'` yields a Contains node. + + Given: + A SELECT query containing `column CONTAINS 'chr1:1500'` + When: + The query is parsed with GIQLDialect + Then: + It should produce exactly one Contains node in the AST + """ + # GD-002 + # Arrange + query = "SELECT * FROM t WHERE column CONTAINS 'chr1:1500'" + + # Act + ast = parse_one( + query, + dialect=GIQLDialect, + ) + + # Assert + nodes = list(ast.find_all(Contains)) + assert len(nodes) == 1 + + def test_parse_one_should_produce_Within_node_for_within_predicate(self): + """Test parsing `column WITHIN 'chr1:1000-5000'` yields a Within node. + + Given: + A SELECT query containing `column WITHIN 'chr1:1000-5000'` + When: + The query is parsed with GIQLDialect + Then: + It should produce exactly one Within node in the AST + """ + # GD-003 + # Arrange + query = "SELECT * FROM t WHERE column WITHIN 'chr1:1000-5000'" + + # Act + ast = parse_one( + query, + dialect=GIQLDialect, + ) + + # Assert + nodes = list(ast.find_all(Within)) + assert len(nodes) == 1 + + def test_parse_one_should_set_quantifier_to_ANY_for_intersects_any(self): + """Test `INTERSECTS ANY(...)` produces a SpatialSetPredicate with quantifier ANY. + + Given: + A SELECT query containing `column INTERSECTS ANY('chr1:1000-2000', 'chr1:5000-6000')` + When: + The query is parsed with GIQLDialect + Then: + It should produce a SpatialSetPredicate whose quantifier argument is "ANY" + """ + # GD-004 + # Arrange + query = ( + "SELECT * FROM t WHERE column INTERSECTS ANY('chr1:1000-2000', 'chr1:5000-6000')" + ) + + # Act + ast = parse_one( + query, + dialect=GIQLDialect, + ) + + # Assert + nodes = list(ast.find_all(SpatialSetPredicate)) + assert len(nodes) == 1 + node = nodes[0] + assert node.args["quantifier"] == "ANY" + + def test_parse_one_should_set_quantifier_to_ALL_for_intersects_all(self): + """Test `INTERSECTS ALL(...)` produces a SpatialSetPredicate with quantifier ALL. + + Given: + A SELECT query containing `column INTERSECTS ALL('chr1:1000-2000', 'chr1:5000-6000')` + When: + The query is parsed with GIQLDialect + Then: + It should produce a SpatialSetPredicate whose quantifier argument is "ALL" + """ + # GD-005 + # Arrange + query = ( + "SELECT * FROM t WHERE column INTERSECTS ALL('chr1:1000-2000', 'chr1:5000-6000')" + ) + + # Act + ast = parse_one( + query, + dialect=GIQLDialect, + ) + + # Assert + nodes = list(ast.find_all(SpatialSetPredicate)) + assert len(nodes) == 1 + node = nodes[0] + assert node.args["quantifier"] == "ALL" + + def test_parse_one_should_produce_plain_select_when_no_spatial_operators_are_used(self): + """Test plain SQL parses without any spatial nodes under GIQLDialect. + + Given: + A SELECT query with no spatial operators + When: + The query is parsed with GIQLDialect + Then: + It should produce a standard Select AST containing no + SpatialPredicate or SpatialSetPredicate nodes + """ + # GD-006 + # Arrange + query = "SELECT id, name FROM t WHERE id = 1" + + # Act + ast = parse_one( + query, + dialect=GIQLDialect, + ) + + # Assert + spatial_nodes = list(ast.find_all(SpatialPredicate, SpatialSetPredicate)) + assert len(spatial_nodes) == 0 + assert ast.find(exp.Select) is not None + + def test_parse_one_should_produce_GIQLCluster_node_for_cluster_call(self): + """Test `CLUSTER(interval)` parses into a GIQLCluster AST node. + + Given: + A SELECT query containing `CLUSTER(interval)` + When: + The query is parsed with GIQLDialect + Then: + It should produce exactly one GIQLCluster node in the AST + """ + # GD-007 + # Arrange + query = "SELECT CLUSTER(interval) FROM t" + + # Act + ast = parse_one( + query, + dialect=GIQLDialect, + ) + + # Assert + nodes = list(ast.find_all(GIQLCluster)) + assert len(nodes) == 1 + + def test_parse_one_should_set_distance_arg_on_GIQLCluster_when_distance_is_given(self): + """Test `CLUSTER(interval, 1000)` sets the distance argument on GIQLCluster. + + Given: + A SELECT query containing `CLUSTER(interval, 1000)` + When: + The query is parsed with GIQLDialect + Then: + It should produce a GIQLCluster node whose distance argument is set + """ + # GD-008 + # Arrange + query = "SELECT CLUSTER(interval, 1000) FROM t" + + # Act + ast = parse_one( + query, + dialect=GIQLDialect, + ) + + # Assert + nodes = list(ast.find_all(GIQLCluster)) + assert len(nodes) == 1 + node = nodes[0] + assert node.args.get("distance") is not None + + def test_parse_one_should_produce_GIQLMerge_node_for_merge_call(self): + """Test `MERGE(interval)` parses into a GIQLMerge AST node. + + Given: + A SELECT query containing `MERGE(interval)` + When: + The query is parsed with GIQLDialect + Then: + It should produce exactly one GIQLMerge node in the AST + """ + # GD-009 + # Arrange + query = "SELECT MERGE(interval) FROM t" + + # Act + ast = parse_one( + query, + dialect=GIQLDialect, + ) + + # Assert + nodes = list(ast.find_all(GIQLMerge)) + assert len(nodes) == 1 + + def test_parse_one_should_set_resolution_arg_on_GIQLRasterize_when_resolution_is_positional(self): + """Test `RASTERIZE(interval, 1000)` sets the resolution argument on GIQLRasterize. + + Given: + A SELECT query containing `RASTERIZE(interval, 1000)` + When: + The query is parsed with GIQLDialect + Then: + It should produce a GIQLRasterize node whose resolution argument is set + """ + # GD-010 + # Arrange + query = "SELECT RASTERIZE(interval, 1000) FROM t" + + # Act + ast = parse_one( + query, + dialect=GIQLDialect, + ) + + # Assert + nodes = list(ast.find_all(GIQLRasterize)) + assert len(nodes) == 1 + node = nodes[0] + assert node.args.get("resolution") is not None + + def test_parse_one_should_set_resolution_arg_on_GIQLRasterize_when_resolution_is_passed_as_kwarg(self): + """Test `RASTERIZE(interval, resolution => 1000)` sets resolution via Kwarg syntax. + + Given: + A SELECT query containing `RASTERIZE(interval, resolution => 1000)` + When: + The query is parsed with GIQLDialect + Then: + It should produce a GIQLRasterize node whose resolution argument is set + """ + # GD-012 + # Arrange + query = "SELECT RASTERIZE(interval, resolution => 1000) FROM t" + + # Act + ast = parse_one( + query, + dialect=GIQLDialect, + ) + + # Assert + nodes = list(ast.find_all(GIQLRasterize)) + assert len(nodes) == 1 + node = nodes[0] + assert node.args.get("resolution") is not None + + def test_parse_one_should_produce_GIQLDistance_node_for_distance_call(self): + """Test `DISTANCE(a.interval, b.interval)` parses into a GIQLDistance AST node. + + Given: + A SELECT query containing `DISTANCE(a.interval, b.interval)` + When: + The query is parsed with GIQLDialect + Then: + It should produce exactly one GIQLDistance node in the AST + """ + # GD-014 + # Arrange + query = "SELECT DISTANCE(a.interval, b.interval) FROM t" + + # Act + ast = parse_one( + query, + dialect=GIQLDialect, + ) + + # Assert + nodes = list(ast.find_all(GIQLDistance)) + assert len(nodes) == 1 + + def test_parse_one_should_set_k_arg_on_GIQLNearest_when_k_named_param_is_given(self): + """Test `NEAREST(genes, k := 3)` sets the k argument on GIQLNearest. + + Given: + A SELECT query containing `NEAREST(genes, k := 3)` + When: + The query is parsed with GIQLDialect + Then: + It should produce a GIQLNearest node whose k argument is set + """ + # GD-015 + # Arrange + query = "SELECT NEAREST(genes, k := 3) FROM t" + + # Act + ast = parse_one( + query, + dialect=GIQLDialect, + ) + + # Assert + nodes = list(ast.find_all(GIQLNearest)) + assert len(nodes) == 1 + node = nodes[0] + assert node.args.get("k") is not None diff --git a/tests/unit/test_expressions.py b/tests/unit/test_expressions.py new file mode 100644 index 0000000..4e25396 --- /dev/null +++ b/tests/unit/test_expressions.py @@ -0,0 +1,688 @@ +"""Tests for custom AST expression nodes. + +Test specification: specs/test_expressions.md +""" + +from hypothesis import HealthCheck +from hypothesis import given +from hypothesis import settings +from hypothesis import strategies as st +from sqlglot import exp +from sqlglot import parse_one + +from giql.dialect import GIQLDialect +from giql.expressions import Contains +from giql.expressions import GenomicRange +from giql.expressions import GIQLCluster +from giql.expressions import GIQLRasterize +from giql.expressions import GIQLDistance +from giql.expressions import GIQLMerge +from giql.expressions import GIQLNearest +from giql.expressions import Intersects +from giql.expressions import SpatialPredicate +from giql.expressions import SpatialSetPredicate +from giql.expressions import Within + + +class TestGenomicRange: + """Tests for GenomicRange expression node.""" + + def test___init___should_succeed_when_required_args_supplied(self): + """Test GenomicRange instantiates with just required args. + + Given: + All required args (chromosome, start, end) + When: + GenomicRange is instantiated + Then: + It should have correct chromosome, start, and end args + """ + # Arrange + chrom = exp.Literal.string("chr1") + start = exp.Literal.number(1000) + end = exp.Literal.number(2000) + + # Act + gr = GenomicRange(chromosome=chrom, start=start, end=end) + + # Assert + assert gr.args["chromosome"] is chrom + assert gr.args["start"] is start + assert gr.args["end"] is end + + def test___init___should_accept_all_args_when_optional_supplied(self): + """Test GenomicRange instantiates with all optional args. + + Given: + Required args plus optional strand and coord_system + When: + GenomicRange is instantiated + Then: + It should have all five args accessible + """ + # Arrange + chrom = exp.Literal.string("chr1") + start = exp.Literal.number(1000) + end = exp.Literal.number(2000) + strand = exp.Literal.string("+") + coord_system = exp.Literal.string("0-based") + + # Act + gr = GenomicRange( + chromosome=chrom, + start=start, + end=end, + strand=strand, + coord_system=coord_system, + ) + + # Assert + assert gr.args["chromosome"] is chrom + assert gr.args["start"] is start + assert gr.args["end"] is end + assert gr.args["strand"] is strand + assert gr.args["coord_system"] is coord_system + + def test___init___should_default_optional_args_to_none_when_omitted(self): + """Test GenomicRange defaults optional args to None. + + Given: + Only required args provided + When: + GenomicRange is instantiated + Then: + It should leave strand and coord_system args as None + """ + # Act + gr = GenomicRange( + chromosome=exp.Literal.string("chr1"), + start=exp.Literal.number(1000), + end=exp.Literal.number(2000), + ) + + # Assert + assert gr.args.get("strand") is None + assert gr.args.get("coord_system") is None + + +class TestSpatialPredicate: + """Tests for SpatialPredicate subclasses.""" + + def test___init___should_produce_spatial_predicate_and_binary_when_intersects(self): + """Test Intersects inherits from SpatialPredicate and exp.Binary. + + Given: + Two expression nodes (this, expression) + When: + Intersects is instantiated + Then: + It should produce an instance of SpatialPredicate and exp.Binary + """ + # Arrange + left = exp.Column(this=exp.Identifier(this="a")) + right = exp.Column(this=exp.Identifier(this="b")) + + # Act + node = Intersects(this=left, expression=right) + + # Assert + assert isinstance(node, SpatialPredicate) + assert isinstance(node, exp.Binary) + + def test___init___should_produce_spatial_predicate_and_binary_when_contains(self): + """Test Contains inherits from SpatialPredicate and exp.Binary. + + Given: + Two expression nodes + When: + Contains is instantiated + Then: + It should produce an instance of SpatialPredicate and exp.Binary + """ + # Arrange + left = exp.Column(this=exp.Identifier(this="a")) + right = exp.Column(this=exp.Identifier(this="b")) + + # Act + node = Contains(this=left, expression=right) + + # Assert + assert isinstance(node, SpatialPredicate) + assert isinstance(node, exp.Binary) + + def test___init___should_produce_spatial_predicate_and_binary_when_within(self): + """Test Within inherits from SpatialPredicate and exp.Binary. + + Given: + Two expression nodes + When: + Within is instantiated + Then: + It should produce an instance of SpatialPredicate and exp.Binary + """ + # Arrange + left = exp.Column(this=exp.Identifier(this="a")) + right = exp.Column(this=exp.Identifier(this="b")) + + # Act + node = Within(this=left, expression=right) + + # Assert + assert isinstance(node, SpatialPredicate) + assert isinstance(node, exp.Binary) + + +class TestSpatialSetPredicate: + """Tests for SpatialSetPredicate expression node.""" + + def test___init___should_set_all_args_when_required_args_supplied(self): + """Test SpatialSetPredicate instantiates with all required args. + + Given: + All required args (this, operator, quantifier, ranges) + When: + SpatialSetPredicate is instantiated + Then: + It should have all four args accessible + """ + # Arrange + this = exp.Column(this=exp.Identifier(this="interval")) + operator = exp.Literal.string("INTERSECTS") + quantifier = exp.Literal.string("ANY") + ranges = exp.Array( + expressions=[ + exp.Literal.string("chr1:1000-2000"), + exp.Literal.string("chr1:5000-6000"), + ] + ) + + # Act + node = SpatialSetPredicate( + this=this, + operator=operator, + quantifier=quantifier, + ranges=ranges, + ) + + # Assert + assert node.args["this"] is this + assert node.args["operator"] is operator + assert node.args["quantifier"] is quantifier + assert node.args["ranges"] is ranges + + +class TestGIQLCluster: + """Tests for GIQLCluster expression node parsing.""" + + def test_parse_should_set_this_when_one_positional_arg(self): + """Test CLUSTER parses with a single positional arg. + + Given: + A CLUSTER expression with one positional arg (column) + When: + Parsed with GIQLDialect + Then: + It should produce a GIQLCluster instance with `this` set + """ + # Act + ast = parse_one( + "SELECT CLUSTER(interval) FROM features", + dialect=GIQLDialect, + ) + + # Assert + nodes = list(ast.find_all(GIQLCluster)) + assert len(nodes) == 1 + assert nodes[0].args["this"] is not None + + def test_parse_should_set_distance_when_two_positional_args(self): + """Test CLUSTER parses with column and distance positionals. + + Given: + A CLUSTER expression with two positional args (column, distance) + When: + Parsed with GIQLDialect + Then: + It should produce a GIQLCluster instance with `this` and `distance` set + """ + # Act + ast = parse_one( + "SELECT CLUSTER(interval, 1000) FROM features", + dialect=GIQLDialect, + ) + + # Assert + nodes = list(ast.find_all(GIQLCluster)) + assert len(nodes) == 1 + assert nodes[0].args["this"] is not None + assert nodes[0].args["distance"].this == "1000" + + def test_parse_should_set_stranded_when_named_parameter_supplied(self): + """Test CLUSTER parses with a stranded named parameter. + + Given: + A CLUSTER expression with one positional and stranded := true + When: + Parsed with GIQLDialect + Then: + It should produce a GIQLCluster instance with `this` and `stranded` set + """ + # Act + ast = parse_one( + "SELECT CLUSTER(interval, stranded := true) FROM features", + dialect=GIQLDialect, + ) + + # Assert + nodes = list(ast.find_all(GIQLCluster)) + assert len(nodes) == 1 + assert nodes[0].args["this"] is not None + assert nodes[0].args["stranded"] is not None + + def test_parse_should_set_distance_and_stranded_when_both_supplied(self): + """Test CLUSTER parses with both distance and stranded params. + + Given: + A CLUSTER expression with two positionals and stranded := true + When: + Parsed with GIQLDialect + Then: + It should produce a GIQLCluster instance with `this`, `distance`, and `stranded` set + """ + # Act + ast = parse_one( + "SELECT CLUSTER(interval, 1000, stranded := true) FROM features", + dialect=GIQLDialect, + ) + + # Assert + nodes = list(ast.find_all(GIQLCluster)) + assert len(nodes) == 1 + assert nodes[0].args["this"] is not None + assert nodes[0].args["distance"].this == "1000" + assert nodes[0].args["stranded"] is not None + + def test___init___should_leave_optional_args_absent_when_only_this_supplied(self): + """Test GIQLCluster direct instantiation with just `this`. + + Given: + Required arg `this` only + When: + GIQLCluster is instantiated directly + Then: + It should set `this` and leave `distance` and `stranded` absent + """ + # Arrange + col = exp.Column(this=exp.Identifier(this="interval")) + + # Act + node = GIQLCluster(this=col) + + # Assert + assert node.args["this"] is col + assert node.args.get("distance") is None + assert node.args.get("stranded") is None + + +class TestGIQLMerge: + """Tests for GIQLMerge expression node parsing.""" + + def test_parse_should_set_this_when_one_positional_arg(self): + """Test MERGE parses with a single positional arg. + + Given: + A MERGE expression with one positional arg (column) + When: + Parsed with GIQLDialect + Then: + It should produce a GIQLMerge instance with `this` set + """ + # Act + ast = parse_one( + "SELECT MERGE(interval) FROM features", + dialect=GIQLDialect, + ) + + # Assert + nodes = list(ast.find_all(GIQLMerge)) + assert len(nodes) == 1 + assert nodes[0].args["this"] is not None + + def test_parse_should_set_distance_when_two_positional_args(self): + """Test MERGE parses with column and distance positionals. + + Given: + A MERGE expression with two positional args (column, distance) + When: + Parsed with GIQLDialect + Then: + It should produce a GIQLMerge instance with `this` and `distance` set + """ + # Act + ast = parse_one( + "SELECT MERGE(interval, 1000) FROM features", + dialect=GIQLDialect, + ) + + # Assert + nodes = list(ast.find_all(GIQLMerge)) + assert len(nodes) == 1 + assert nodes[0].args["this"] is not None + assert nodes[0].args["distance"].this == "1000" + + def test_parse_should_set_stranded_when_named_parameter_supplied(self): + """Test MERGE parses with a stranded named parameter. + + Given: + A MERGE expression with one positional and stranded := true + When: + Parsed with GIQLDialect + Then: + It should produce a GIQLMerge instance with `this` and `stranded` set + """ + # Act + ast = parse_one( + "SELECT MERGE(interval, stranded := true) FROM features", + dialect=GIQLDialect, + ) + + # Assert + nodes = list(ast.find_all(GIQLMerge)) + assert len(nodes) == 1 + assert nodes[0].args["this"] is not None + assert nodes[0].args["stranded"] is not None + + def test_parse_should_set_distance_and_stranded_when_both_supplied(self): + """Test MERGE parses with both distance and stranded params. + + Given: + A MERGE expression with two positionals and stranded := true + When: + Parsed with GIQLDialect + Then: + It should produce a GIQLMerge instance with `this`, `distance`, and `stranded` set + """ + # Act + ast = parse_one( + "SELECT MERGE(interval, 1000, stranded := true) FROM features", + dialect=GIQLDialect, + ) + + # Assert + nodes = list(ast.find_all(GIQLMerge)) + assert len(nodes) == 1 + assert nodes[0].args["this"] is not None + assert nodes[0].args["distance"].this == "1000" + assert nodes[0].args["stranded"] is not None + + +class TestGIQLRasterize: + """Tests for GIQLRasterize expression node parsing.""" + + # ------------------------------------------------------------------ + # Example-based parsing (COV-001 to COV-007) + # ------------------------------------------------------------------ + + def test_from_arg_list_should_map_resolution_when_positional(self): + """Test positional interval and resolution mapping. + + Given: + A RASTERIZE expression with positional interval and resolution + When: + Parsed with GIQLDialect + Then: + It should produce a GIQLRasterize node with resolution set + """ + # Act + ast = parse_one( + "SELECT RASTERIZE(interval, 1000) FROM features", + dialect=GIQLDialect, + ) + + # Assert + rasterize = list(ast.find_all(GIQLRasterize)) + assert len(rasterize) == 1 + assert rasterize[0].args["resolution"].this == "1000" + + def test_from_arg_list_should_set_resolution_when_walrus_syntax(self): + """Test named resolution parameter via := syntax. + + Given: + A RASTERIZE expression with `resolution := 1000` + When: + Parsed with GIQLDialect + Then: + It should produce a GIQLRasterize node with resolution set via named param + """ + # Act + ast = parse_one( + "SELECT RASTERIZE(interval, resolution := 1000) FROM features", + dialect=GIQLDialect, + ) + + # Assert + rasterize = list(ast.find_all(GIQLRasterize)) + assert len(rasterize) == 1 + assert rasterize[0].args["resolution"].this == "1000" + + def test_from_arg_list_should_set_resolution_when_arrow_syntax(self): + """Test named resolution parameter via => syntax. + + Given: + A RASTERIZE expression with `resolution => 1000` + When: + Parsed with GIQLDialect + Then: + It should produce a GIQLRasterize node with resolution set via named param + """ + # Act + ast = parse_one( + "SELECT RASTERIZE(interval, resolution => 1000) FROM features", + dialect=GIQLDialect, + ) + + # Assert + rasterize = list(ast.find_all(GIQLRasterize)) + assert len(rasterize) == 1 + assert rasterize[0].args["resolution"].this == "1000" + + # ------------------------------------------------------------------ + # Property-based parsing (PBT-001 to PBT-002) + # ------------------------------------------------------------------ + + @given(resolution=st.integers(min_value=1, max_value=10_000_000)) + @settings(max_examples=50, suppress_health_check=[HealthCheck.function_scoped_fixture]) + def test_from_arg_list_should_set_resolution_when_positional(self, resolution): + """Test positional resolution parses correctly across the resolution range. + + Given: + Any valid resolution (1-10M) supplied positionally + When: + Parsed with GIQLDialect + Then: + It should produce a GIQLRasterize node with the matching resolution + """ + # Act + ast = parse_one( + f"SELECT RASTERIZE(interval, {resolution}) FROM features", + dialect=GIQLDialect, + ) + + # Assert + rasterize = list(ast.find_all(GIQLRasterize)) + assert len(rasterize) == 1 + assert rasterize[0].args["resolution"].this == str(resolution) + + @given( + resolution=st.integers(min_value=1, max_value=10_000_000), + syntax=st.sampled_from([":=", "=>"]), + ) + @settings(max_examples=50, suppress_health_check=[HealthCheck.function_scoped_fixture]) + def test_from_arg_list_should_set_resolution_when_named_with_either_syntax( + self, resolution, syntax + ): + """Test named resolution parses correctly with either := or => syntax. + + Given: + Any valid resolution (1-10M) supplied with either `:=` or `=>` + When: + Parsed with GIQLDialect + Then: + It should produce a GIQLRasterize node with the matching resolution + """ + # Act + ast = parse_one( + f"SELECT RASTERIZE(interval, resolution {syntax} {resolution}) FROM features", + dialect=GIQLDialect, + ) + + # Assert + rasterize = list(ast.find_all(GIQLRasterize)) + assert len(rasterize) == 1 + assert rasterize[0].args["resolution"].this == str(resolution) + + +class TestGIQLDistance: + """Tests for GIQLDistance expression node parsing.""" + + def test_parse_should_set_this_and_expression_when_two_positional_args(self): + """Test DISTANCE parses with two positional interval args. + + Given: + A DISTANCE expression with two positional args (interval_a, interval_b) + When: + Parsed with GIQLDialect + Then: + It should produce a GIQLDistance instance with `this` and `expression` set + """ + # Act + ast = parse_one( + "SELECT DISTANCE(a.interval, b.interval) FROM a, b", + dialect=GIQLDialect, + ) + + # Assert + nodes = list(ast.find_all(GIQLDistance)) + assert len(nodes) == 1 + assert nodes[0].args["this"] is not None + assert nodes[0].args["expression"] is not None + + def test_parse_should_set_stranded_and_signed_when_both_named_params(self): + """Test DISTANCE parses with stranded and signed named params. + + Given: + A DISTANCE expression with two positionals and stranded := true, signed := true + When: + Parsed with GIQLDialect + Then: + It should produce a GIQLDistance instance with `this`, `expression`, `stranded`, and `signed` set + """ + # Act + ast = parse_one( + "SELECT DISTANCE(a.interval, b.interval, stranded := true, signed := true) FROM a, b", + dialect=GIQLDialect, + ) + + # Assert + nodes = list(ast.find_all(GIQLDistance)) + assert len(nodes) == 1 + assert nodes[0].args["this"] is not None + assert nodes[0].args["expression"] is not None + assert nodes[0].args["stranded"] is not None + assert nodes[0].args["signed"] is not None + + def test_parse_should_leave_signed_absent_when_only_stranded_supplied(self): + """Test DISTANCE parses with only stranded named param. + + Given: + A DISTANCE expression with two positionals and only stranded := true + When: + Parsed with GIQLDialect + Then: + It should produce a GIQLDistance instance with `this`, `expression`, and `stranded` set; `signed` absent + """ + # Act + ast = parse_one( + "SELECT DISTANCE(a.interval, b.interval, stranded := true) FROM a, b", + dialect=GIQLDialect, + ) + + # Assert + nodes = list(ast.find_all(GIQLDistance)) + assert len(nodes) == 1 + assert nodes[0].args["this"] is not None + assert nodes[0].args["expression"] is not None + assert nodes[0].args["stranded"] is not None + assert nodes[0].args.get("signed") is None + + +class TestGIQLNearest: + """Tests for GIQLNearest expression node parsing.""" + + def test_parse_should_set_this_when_one_positional_arg(self): + """Test NEAREST parses with a single positional table arg. + + Given: + A NEAREST expression with one positional arg (table) + When: + Parsed with GIQLDialect + Then: + It should produce a GIQLNearest instance with `this` set + """ + # Act + ast = parse_one( + "SELECT NEAREST(genes) FROM peaks", + dialect=GIQLDialect, + ) + + # Assert + nodes = list(ast.find_all(GIQLNearest)) + assert len(nodes) == 1 + assert nodes[0].args["this"] is not None + + def test_parse_should_set_k_when_named_parameter_supplied(self): + """Test NEAREST parses with a k named parameter. + + Given: + A NEAREST expression with one positional and k := 3 + When: + Parsed with GIQLDialect + Then: + It should produce a GIQLNearest instance with `this` and `k` set + """ + # Act + ast = parse_one( + "SELECT NEAREST(genes, k := 3) FROM peaks", + dialect=GIQLDialect, + ) + + # Assert + nodes = list(ast.find_all(GIQLNearest)) + assert len(nodes) == 1 + assert nodes[0].args["this"] is not None + assert nodes[0].args["k"].this == "3" + + def test_parse_should_set_all_args_when_multiple_named_params(self): + """Test NEAREST parses with multiple named params. + + Given: + A NEAREST expression with one positional and multiple named params + When: + Parsed with GIQLDialect + Then: + It should produce a GIQLNearest instance with all provided args set + """ + # Act + ast = parse_one( + "SELECT NEAREST(genes, k := 5, max_distance := 100000, stranded := true, signed := true) FROM peaks", + dialect=GIQLDialect, + ) + + # Assert + nodes = list(ast.find_all(GIQLNearest)) + assert len(nodes) == 1 + assert nodes[0].args["this"] is not None + assert nodes[0].args["k"].this == "5" + assert nodes[0].args["max_distance"].this == "100000" + assert nodes[0].args["stranded"] is not None + assert nodes[0].args["signed"] is not None diff --git a/tests/unit/test_generators_base.py b/tests/unit/test_generators_base.py new file mode 100644 index 0000000..7467c78 --- /dev/null +++ b/tests/unit/test_generators_base.py @@ -0,0 +1,600 @@ +"""Tests for BaseGIQLGenerator. + +Test specification: specs/test_generators_base.md +Test IDs: BG-001 through BG-020 +""" + +import pytest +from sqlglot import parse_one + +from giql.dialect import GIQLDialect +from giql.generators import BaseGIQLGenerator +from giql.table import Table +from giql.table import Tables + + +@pytest.fixture +def tables_two(): + """Tables with two tables for column-to-column tests.""" + tables = Tables() + tables.register("features_a", Table("features_a")) + tables.register("features_b", Table("features_b")) + return tables + + +@pytest.fixture +def tables_peaks_and_genes(): + """Tables with peaks and genes for NEAREST/DISTANCE tests.""" + tables = Tables() + tables.register("peaks", Table("peaks")) + tables.register("genes", Table("genes")) + return tables + + +def _normalize(sql: str) -> str: + """Collapse whitespace for easier assertion.""" + return " ".join(sql.split()) + + +class TestBaseGIQLGenerator: + """Tests for BaseGIQLGenerator class (BG-001 to BG-020).""" + + # ------------------------------------------------------------------ + # Instantiation + # ------------------------------------------------------------------ + + def test___init___should_use_defaults_when_no_args(self): + """Test __init__ uses default state when no arguments are supplied. + + Given: + No arguments + When: + BaseGIQLGenerator is instantiated + Then: + It should have empty Tables and SUPPORTS_LATERAL set to True + """ + # Arrange / Act + generator = BaseGIQLGenerator() + + # Assert + assert generator.tables is not None + assert generator.SUPPORTS_LATERAL is True + # Empty tables: looking up any name returns None + assert generator.tables.get("anything") is None + + def test___init___should_use_provided_tables_when_given(self): + """Test __init__ adopts a caller-supplied Tables instance. + + Given: + A Tables instance with a registered table + When: + BaseGIQLGenerator is instantiated with tables= + Then: + It should use the provided tables for column resolution + """ + # Arrange + tables = Tables() + tables.register("peaks", Table("peaks")) + + # Act + generator = BaseGIQLGenerator(tables=tables) + + # Assert + assert generator.tables is tables + assert "peaks" in generator.tables + + # ------------------------------------------------------------------ + # Spatial predicates + # ------------------------------------------------------------------ + + def test_generate_should_emit_overlap_conditions_when_intersects_literal(self): + """Test generate emits overlap SQL for an INTERSECTS literal range. + + Given: + An Intersects AST node with a literal range 'chr1:1000-2000' + When: + generate is called + Then: + It should contain chrom = 'chr1' AND start < 2000 AND end > 1000 + """ + # Arrange + tables = Tables() + tables.register("peaks", Table("peaks")) + generator = BaseGIQLGenerator(tables=tables) + ast = parse_one( + "SELECT * FROM peaks WHERE interval INTERSECTS 'chr1:1000-2000'", + dialect=GIQLDialect, + ) + + # Act + sql = generator.generate(ast) + + # Assert + assert "\"chrom\" = 'chr1'" in sql + assert '"start" < 2000' in sql + assert '"end" > 1000' in sql + + def test_generate_should_emit_qualified_overlap_when_intersects_column_to_column(self, tables_two): + """Test generate emits table-qualified overlap for column-to-column INTERSECTS. + + Given: + An Intersects AST node with column-to-column (a.interval INTERSECTS b.interval) + When: + generate is called + Then: + It should contain chrom equality and overlap conditions using both table prefixes + """ + # Arrange + generator = BaseGIQLGenerator(tables=tables_two) + ast = parse_one( + "SELECT * FROM features_a AS a CROSS JOIN features_b AS b " + "WHERE a.interval INTERSECTS b.interval", + dialect=GIQLDialect, + ) + + # Act + sql = generator.generate(ast) + + # Assert + assert 'a."chrom" = b."chrom"' in sql + assert 'a."start" < b."end"' in sql + assert 'a."end" > b."start"' in sql + + def test_generate_should_emit_point_containment_when_contains_point(self): + """Test generate emits point containment SQL when CONTAINS targets a point. + + Given: + A Contains AST node with a point range 'chr1:1500' + When: + generate is called + Then: + It should contain point containment predicate + """ + # Arrange + generator = BaseGIQLGenerator() + ast = parse_one( + "SELECT * FROM peaks WHERE interval CONTAINS 'chr1:1500'", + dialect=GIQLDialect, + ) + + # Act + sql = generator.generate(ast) + + # Assert + assert "\"chrom\" = 'chr1'" in sql + assert '"start" <= 1500' in sql + assert '"end" > 1500' in sql + + def test_generate_should_emit_range_containment_when_contains_range(self): + """Test generate emits range containment SQL when CONTAINS targets a range. + + Given: + A Contains AST node with a range 'chr1:1000-2000' + When: + generate is called + Then: + It should contain range containment predicate + """ + # Arrange + generator = BaseGIQLGenerator() + ast = parse_one( + "SELECT * FROM peaks WHERE interval CONTAINS 'chr1:1000-2000'", + dialect=GIQLDialect, + ) + + # Act + sql = generator.generate(ast) + + # Assert + assert "\"chrom\" = 'chr1'" in sql + assert '"start" <= 1000' in sql + assert '"end" >= 2000' in sql + + def test_generate_should_emit_within_predicate_when_within_range(self): + """Test generate emits within-range SQL when the predicate is WITHIN. + + Given: + A Within AST node with a range 'chr1:1000-5000' + When: + generate is called + Then: + It should contain within predicate + """ + # Arrange + generator = BaseGIQLGenerator() + ast = parse_one( + "SELECT * FROM peaks WHERE interval WITHIN 'chr1:1000-5000'", + dialect=GIQLDialect, + ) + + # Act + sql = generator.generate(ast) + + # Assert + assert "\"chrom\" = 'chr1'" in sql + assert '"start" >= 1000' in sql + assert '"end" <= 5000' in sql + + # ------------------------------------------------------------------ + # Spatial set predicates + # ------------------------------------------------------------------ + + def test_generate_should_join_with_or_when_intersects_any(self): + """Test generate joins predicates with OR for INTERSECTS ANY. + + Given: + A SpatialSetPredicate with INTERSECTS ANY and two ranges + When: + generate is called + Then: + It should contain two conditions joined by OR + """ + # Arrange + generator = BaseGIQLGenerator() + ast = parse_one( + "SELECT * FROM peaks " + "WHERE interval INTERSECTS ANY('chr1:1000-2000', 'chr1:5000-6000')", + dialect=GIQLDialect, + ) + + # Act + sql = generator.generate(ast) + + # Assert + assert " OR " in sql + assert '"end" > 1000' in sql + assert '"end" > 5000' in sql + + def test_generate_should_join_with_and_when_intersects_all(self): + """Test generate joins predicates with AND for INTERSECTS ALL. + + Given: + A SpatialSetPredicate with INTERSECTS ALL and two ranges + When: + generate is called + Then: + It should contain two conditions joined by AND + """ + # Arrange + generator = BaseGIQLGenerator() + ast = parse_one( + "SELECT * FROM peaks " + "WHERE interval INTERSECTS ALL('chr1:1000-2000', 'chr1:1500-1800')", + dialect=GIQLDialect, + ) + + # Act + sql = generator.generate(ast) + + # Assert + # The outer WHERE already has AND, but the set predicate wraps + # its conditions in parens joined by AND. + norm = _normalize(sql) + # Both range predicates should appear + assert '"start" < 2000' in sql + assert '"start" < 1800' in sql + # They are joined by AND (inside the set predicate parentheses) + # Check the pattern: one condition AND another condition + idx_first = norm.index('"start" < 2000') + idx_second = norm.index('"start" < 1800') + between = norm[idx_first:idx_second] + assert "AND" in between + + # ------------------------------------------------------------------ + # DISTANCE + # ------------------------------------------------------------------ + + def test_generate_should_emit_case_when_distance_basic(self, tables_two): + """Test generate emits a CASE WHEN expression for basic DISTANCE. + + Given: + A GIQLDistance node with two column references + When: + generate is called + Then: + It should contain CASE WHEN with chromosome check, overlap check, and distance calculations + """ + # Arrange + generator = BaseGIQLGenerator(tables=tables_two) + ast = parse_one( + "SELECT DISTANCE(a.interval, b.interval) AS dist " + "FROM features_a a CROSS JOIN features_b b", + dialect=GIQLDialect, + ) + + # Act + sql = generator.generate(ast) + + # Assert + assert 'a."chrom" != b."chrom" THEN NULL' in sql + assert "THEN 0" in sql + assert 'b."start" - a."end"' in sql + assert 'a."start" - b."end"' in sql + assert sql.startswith("SELECT CASE WHEN") + + def test_generate_should_emit_strand_logic_when_distance_stranded(self, tables_two): + """Test generate emits strand NULL checks and flip logic when DISTANCE is stranded. + + Given: + A GIQLDistance node with stranded := true + When: + generate is called + Then: + It should contain strand NULL checks and strand flip logic + """ + # Arrange + generator = BaseGIQLGenerator(tables=tables_two) + ast = parse_one( + "SELECT DISTANCE(a.interval, b.interval, stranded := true) AS dist " + "FROM features_a a CROSS JOIN features_b b", + dialect=GIQLDialect, + ) + + # Act + sql = generator.generate(ast) + + # Assert + assert 'a."strand" IS NULL' in sql + assert 'b."strand" IS NULL' in sql + assert "a.\"strand\" = '.'" in sql + assert "a.\"strand\" = '?'" in sql + assert "a.\"strand\" = '-'" in sql + + def test_generate_should_emit_signed_distance_when_distance_signed(self, tables_two): + """Test generate emits a negated upstream branch when DISTANCE is signed. + + Given: + A GIQLDistance node with signed := true + When: + generate is called + Then: + It should contain signed distance (negative for upstream) + """ + # Arrange + generator = BaseGIQLGenerator(tables=tables_two) + ast = parse_one( + "SELECT DISTANCE(a.interval, b.interval, signed := true) AS dist " + "FROM features_a a CROSS JOIN features_b b", + dialect=GIQLDialect, + ) + + # Act + sql = generator.generate(ast) + + # Assert + # Signed: ELSE branch has negative sign + assert "-(" in sql + # Unsigned ELSE would be (a."start" - b."end") without negation + # Signed ELSE is -(a."start" - b."end") + assert '-(a."start" - b."end")' in sql + + def test_generate_should_combine_strand_and_sign_when_distance_stranded_and_signed(self, tables_two): + """Test generate combines strand flipping and signed output when both flags are set. + + Given: + A GIQLDistance node with stranded := true and signed := true + When: + generate is called + Then: + It should contain both strand flip and signed distance + """ + # Arrange + generator = BaseGIQLGenerator(tables=tables_two) + ast = parse_one( + "SELECT DISTANCE(a.interval, b.interval, stranded := true, signed := true) AS dist " + "FROM features_a a CROSS JOIN features_b b", + dialect=GIQLDialect, + ) + + # Act + sql = generator.generate(ast) + + # Assert + # Should have strand NULL checks + assert 'a."strand" IS NULL' in sql + # Should have strand flip + assert "a.\"strand\" = '-'" in sql + # Stranded+signed: the ELSE for '-' strand flips sign differently + # from stranded-only + # In stranded+signed: ELSE WHEN strand='-' THEN (a.start - b.end) + # In stranded-only: ELSE WHEN strand='-' THEN -(a.start - b.end) + assert '(a."start" - b."end")' in sql + assert '-(a."start" - b."end")' in sql + + def test_generate_should_add_gap_adjustment_when_distance_uses_closed_intervals(self): + """Test generate adds a +1 gap adjustment for closed-interval DISTANCE. + + Given: + Tables with interval_type="closed" for one table + When: + generate is called for a DISTANCE expression + Then: + It should contain '+ 1' gap adjustment + """ + # Arrange + tables = Tables() + tables.register("bed_a", Table("bed_a", interval_type="closed")) + tables.register("bed_b", Table("bed_b", interval_type="closed")) + generator = BaseGIQLGenerator(tables=tables) + ast = parse_one( + "SELECT DISTANCE(a.interval, b.interval) AS dist " + "FROM bed_a a CROSS JOIN bed_b b", + dialect=GIQLDialect, + ) + + # Act + sql = generator.generate(ast) + + # Assert + assert "+ 1)" in sql + + # ------------------------------------------------------------------ + # NEAREST + # ------------------------------------------------------------------ + + def test_generate_should_emit_order_by_and_limit_when_nearest_standalone(self, tables_peaks_and_genes): + """Test generate emits an ORDER BY / LIMIT subquery for standalone NEAREST. + + Given: + A GIQLNearest node with explicit reference (standalone mode) + When: + generate is called + Then: + It should produce a subquery with WHERE, ORDER BY ABS(distance), and LIMIT + """ + # Arrange + generator = BaseGIQLGenerator(tables=tables_peaks_and_genes) + ast = parse_one( + "SELECT * FROM NEAREST(genes, reference := 'chr1:1000-2000')", + dialect=GIQLDialect, + ) + + # Act + sql = generator.generate(ast) + norm = _normalize(sql) + + # Assert + assert "WHERE" in norm + assert "ORDER BY ABS(" in norm + assert "LIMIT 1" in norm + assert "'chr1' = genes.\"chrom\"" in sql + assert "AS distance" in sql + + def test_generate_should_limit_five_when_nearest_k_is_five(self, tables_peaks_and_genes): + """Test generate applies LIMIT 5 when NEAREST is given k := 5. + + Given: + A GIQLNearest node with k := 5 + When: + generate is called + Then: + It should produce LIMIT 5 + """ + # Arrange + generator = BaseGIQLGenerator(tables=tables_peaks_and_genes) + ast = parse_one( + "SELECT * FROM NEAREST(genes, reference := 'chr1:1000-2000', k := 5)", + dialect=GIQLDialect, + ) + + # Act + sql = generator.generate(ast) + + # Assert + assert "LIMIT 5" in sql + + def test_generate_should_embed_threshold_when_nearest_max_distance(self, tables_peaks_and_genes): + """Test generate embeds the max_distance threshold in the WHERE clause. + + Given: + A GIQLNearest node with max_distance := 100000 + When: + generate is called + Then: + It should place the distance threshold in the WHERE clause + """ + # Arrange + generator = BaseGIQLGenerator(tables=tables_peaks_and_genes) + ast = parse_one( + "SELECT * FROM NEAREST(genes, reference := 'chr1:1000-2000', max_distance := 100000)", + dialect=GIQLDialect, + ) + + # Act + sql = generator.generate(ast) + norm = _normalize(sql) + + # Assert + assert "100000" in norm + assert "<= 100000" in norm + + def test_generate_should_reference_outer_columns_when_nearest_correlated_lateral(self, tables_peaks_and_genes): + """Test generate emits a LATERAL-compatible subquery referencing outer columns. + + Given: + A GIQLNearest node in correlated mode (no standalone reference, in LATERAL context) + When: + generate is called + Then: + It should produce a LATERAL-compatible subquery referencing the outer table columns + """ + # Arrange + generator = BaseGIQLGenerator(tables=tables_peaks_and_genes) + ast = parse_one( + "SELECT * FROM peaks " + "CROSS JOIN LATERAL NEAREST(genes, reference := peaks.interval, k := 3)", + dialect=GIQLDialect, + ) + + # Act + sql = generator.generate(ast) + norm = _normalize(sql) + + # Assert + assert "LATERAL" in norm + assert 'peaks."chrom"' in sql + assert 'genes."chrom"' in sql + assert "LIMIT 3" in sql + + def test_generate_should_match_strand_when_nearest_stranded(self, tables_peaks_and_genes): + """Test generate includes strand matching in WHERE when NEAREST is stranded. + + Given: + A GIQLNearest node with stranded := true + When: + generate is called + Then: + It should include strand matching in the WHERE clause + """ + # Arrange + generator = BaseGIQLGenerator(tables=tables_peaks_and_genes) + ast = parse_one( + "SELECT * FROM peaks " + "CROSS JOIN LATERAL NEAREST(genes, reference := peaks.interval, k := 3, stranded := true)", + dialect=GIQLDialect, + ) + + # Act + sql = generator.generate(ast) + + # Assert + assert 'peaks."strand"' in sql + assert 'genes."strand"' in sql + # Strand matching in WHERE + assert 'peaks."strand" = genes."strand"' in sql + + # ------------------------------------------------------------------ + # SELECT override + # ------------------------------------------------------------------ + + def test_generate_should_resolve_aliases_when_select_has_alias_mapping(self): + """Test generate resolves FROM/JOIN aliases to registered tables. + + Given: + A SELECT with aliased FROM and JOIN tables + When: + generate is called + Then: + It should build alias-to-table mapping correctly, verified through correct column resolution in a spatial op + """ + # Arrange + tables = Tables() + tables.register("features_a", Table("features_a")) + tables.register("features_b", Table("features_b")) + generator = BaseGIQLGenerator(tables=tables) + ast = parse_one( + "SELECT * FROM features_a AS a " + "JOIN features_b AS b ON a.id = b.id " + "WHERE a.interval INTERSECTS b.interval", + dialect=GIQLDialect, + ) + + # Act + sql = generator.generate(ast) + + # Assert + # The aliases 'a' and 'b' should resolve to the registered tables + # and produce correctly qualified column references + assert 'a."chrom" = b."chrom"' in sql + assert 'a."start" < b."end"' in sql + assert 'a."end" > b."start"' in sql diff --git a/tests/unit/test_table.py b/tests/unit/test_table.py new file mode 100644 index 0000000..d8a17b1 --- /dev/null +++ b/tests/unit/test_table.py @@ -0,0 +1,333 @@ +"""Tests for giql.table module.""" + +import pytest +from hypothesis import given +from hypothesis import settings +from hypothesis import strategies as st + +from giql.table import Table +from giql.table import Tables + + +class TestTable: + """Tests for the Table dataclass.""" + + def test___init___should_use_default_values_when_only_name_provided(self): + """Test Table uses default values when only `name` is provided. + + Given: + Only the required arg `name` + When: + Table is instantiated + Then: + It should set all fields to their default values + """ + # Arrange / Act + table = Table(name="peaks") + + # Assert + assert table.name == "peaks" + assert table.genomic_col == "interval" + assert table.chrom_col == "chrom" + assert table.start_col == "start" + assert table.end_col == "end" + assert table.strand_col == "strand" + assert table.coordinate_system == "0based" + assert table.interval_type == "half_open" + + def test___init___should_reflect_custom_values_when_all_fields_provided(self): + """Test Table reflects custom values when all fields are provided. + + Given: + All fields provided with custom values + When: + Table is instantiated + Then: + It should populate all fields with the custom values + """ + # Arrange / Act + table = Table( + name="variants", + genomic_col="position", + chrom_col="chr", + start_col="pos_start", + end_col="pos_end", + strand_col="direction", + coordinate_system="1based", + interval_type="closed", + ) + + # Assert + assert table.name == "variants" + assert table.genomic_col == "position" + assert table.chrom_col == "chr" + assert table.start_col == "pos_start" + assert table.end_col == "pos_end" + assert table.strand_col == "direction" + assert table.coordinate_system == "1based" + assert table.interval_type == "closed" + + def test___init___should_allow_none_when_strand_col_is_none(self): + """Test Table allows strand_col to be None. + + Given: + strand_col=None + When: + Table is instantiated + Then: + It should set strand_col to None + """ + # Arrange / Act + table = Table(name="peaks", strand_col=None) + + # Assert + assert table.strand_col is None + + def test___init___should_accept_1based_when_coordinate_system_is_1based(self): + """Test Table accepts the 1based coordinate system. + + Given: + coordinate_system="1based" + When: + Table is instantiated + Then: + It should set coordinate_system to "1based" + """ + # Arrange / Act + table = Table(name="peaks", coordinate_system="1based") + + # Assert + assert table.coordinate_system == "1based" + + def test___init___should_accept_closed_when_interval_type_is_closed(self): + """Test Table accepts the closed interval type. + + Given: + interval_type="closed" + When: + Table is instantiated + Then: + It should set interval_type to "closed" + """ + # Arrange / Act + table = Table(name="peaks", interval_type="closed") + + # Assert + assert table.interval_type == "closed" + + def test___init___should_raise_when_coordinate_system_invalid(self): + """Test Table raises when coordinate_system is invalid. + + Given: + coordinate_system="invalid" + When: + Table is instantiated + Then: + It should raise ValueError mentioning coordinate_system + """ + # Arrange / Act / Assert + with pytest.raises(ValueError, match="coordinate_system"): + Table(name="peaks", coordinate_system="invalid") + + def test___init___should_raise_when_interval_type_invalid(self): + """Test Table raises when interval_type is invalid. + + Given: + interval_type="invalid" + When: + Table is instantiated + Then: + It should raise ValueError mentioning interval_type + """ + # Arrange / Act / Assert + with pytest.raises(ValueError, match="interval_type"): + Table(name="peaks", interval_type="invalid") + + @given( + coordinate_system=st.sampled_from(["0based", "1based"]), + interval_type=st.sampled_from(["half_open", "closed"]), + ) + @settings(max_examples=20) + def test___init___should_not_raise_when_params_are_valid( + self, coordinate_system, interval_type + ): + """Test Table never raises for any valid parameter combination. + + Given: + Any Table with valid coordinate_system and interval_type + When: + Table is instantiated + Then: + It should not raise and all fields should be accessible + """ + # Arrange / Act + table = Table( + name="test", + coordinate_system=coordinate_system, + interval_type=interval_type, + ) + + # Assert + assert table.coordinate_system == coordinate_system + assert table.interval_type == interval_type + + +class TestTables: + """Tests for the Tables container class.""" + + def test_get_should_return_none_when_name_absent(self): + """Test get returns None for an unregistered name. + + Given: + A fresh Tables instance + When: + get is called with an unregistered name + Then: + It should return None + """ + # Arrange + tables = Tables() + + # Act / Assert + assert tables.get("unknown") is None + + def test_get_should_return_table_when_name_registered(self): + """Test get returns the Table for a registered name. + + Given: + A Tables instance with one registered table + When: + get is called with the registered name + Then: + It should return the registered Table object + """ + # Arrange + tables = Tables() + table = Table(name="peaks") + tables.register("peaks", table) + + # Act / Assert + assert tables.get("peaks") is table + + def test_register_should_store_all_tables_when_called_multiple_times(self): + """Test register stores every table when called with distinct names. + + Given: + A Tables instance with one registered table + When: + register is called with a new name and Table + Then: + It should make both tables retrievable via get + """ + # Arrange + tables = Tables() + peaks = Table(name="peaks") + variants = Table(name="variants") + tables.register("peaks", peaks) + tables.register("variants", variants) + + # Act / Assert + assert tables.get("peaks") is peaks + assert tables.get("variants") is variants + + def test_register_should_overwrite_when_name_already_registered(self): + """Test register overwrites an existing entry with the same name. + + Given: + A Tables instance with a registered table + When: + register is called with the same name and a different Table + Then: + It should make get return the new Table + """ + # Arrange + tables = Tables() + old_table = Table(name="peaks") + new_table = Table(name="peaks", chrom_col="chr") + tables.register("peaks", old_table) + tables.register("peaks", new_table) + + # Act / Assert + assert tables.get("peaks") is new_table + + def test___contains___should_return_true_when_name_registered(self): + """Test __contains__ returns True for a registered name. + + Given: + A Tables instance with registered tables + When: + the in operator is used with a registered name + Then: + It should return True + """ + # Arrange + tables = Tables() + tables.register("peaks", Table(name="peaks")) + + # Act / Assert + assert "peaks" in tables + + def test___contains___should_return_false_when_name_absent(self): + """Test __contains__ returns False for an unregistered name. + + Given: + A Tables instance with registered tables + When: + the in operator is used with an unregistered name + Then: + It should return False + """ + # Arrange + tables = Tables() + tables.register("peaks", Table(name="peaks")) + + # Act / Assert + assert "unknown" not in tables + + def test___iter___should_yield_all_registered(self): + """Test __iter__ yields all registered Table objects. + + Given: + A Tables instance with registered tables + When: + iterated with a for loop + Then: + It should yield all registered Table objects + """ + # Arrange + tables = Tables() + peaks = Table(name="peaks") + variants = Table(name="variants") + tables.register("peaks", peaks) + tables.register("variants", variants) + + # Act + result = [] + for table in tables: + result.append(table) + + # Assert + assert len(result) == 2 + assert peaks in result + assert variants in result + + def test___iter___should_yield_nothing_when_empty(self): + """Test __iter__ yields nothing when no tables are registered. + + Given: + A fresh Tables instance with no tables + When: + iterated with a for loop + Then: + It should yield nothing + """ + # Arrange + tables = Tables() + + # Act + result = [] + for table in tables: + result.append(table) + + # Assert + assert result == [] diff --git a/tests/unit/test_transformer.py b/tests/unit/test_transformer.py new file mode 100644 index 0000000..ddfe6d8 --- /dev/null +++ b/tests/unit/test_transformer.py @@ -0,0 +1,1197 @@ +"""Tests for the transformer module. + +Test specification: specs/test_transformer.md +""" + +import duckdb +import pytest +from hypothesis import HealthCheck +from hypothesis import given +from hypothesis import settings +from hypothesis import strategies as st +from sqlglot import exp +from sqlglot import parse_one + +from giql import Table +from giql import transpile +from giql.dialect import GIQLDialect +from giql.generators import BaseGIQLGenerator +from giql.table import Tables +from giql.transformer import ClusterTransformer +from giql.transformer import RasterizeTransformer +from giql.transformer import MergeTransformer + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + + +def _make_tables(*names: str, **custom: Table) -> Tables: + tables = Tables() + for name in names: + tables.register(name, Table(name)) + for name, table in custom.items(): + tables.register(name, table) + return tables + + +def _transpile_with_transformer( + query: str, transformer_cls, tables: Tables | None = None +) -> str: + """Run the full parse-transform-generate pipeline for SQL-substring assertions. + + Returned SQL reflects the composition of parser, ``transformer_cls``, + and :class:`BaseGIQLGenerator`. Tests that assert on SQL output are + exercising the end-to-end transpilation contract; if one of them + fails, check all three stages to localise the regression rather + than assuming the transformer is at fault. + """ + tables = tables or _make_tables("features") + ast = parse_one(query, dialect=GIQLDialect) + transformer = transformer_cls(tables) + result = transformer.transform(ast) + generator = BaseGIQLGenerator(tables=tables) + return generator.generate(result) + + +# =========================================================================== +# TestClusterTransformer +# =========================================================================== + + +class TestClusterTransformer: + """Tests for ClusterTransformer.transform.""" + + def test_transform_should_produce_lag_and_sum_windows_when_basic_cluster(self): + """Test basic CLUSTER produces LAG and SUM window expressions. + + Given: + A Tables instance and a parsed SELECT with CLUSTER(interval) + When: + transform is called + Then: + It should produce a result containing LAG and SUM window expressions + """ + # Act + sql = _transpile_with_transformer( + "SELECT *, CLUSTER(interval) FROM features", ClusterTransformer + ) + + # Assert + upper = sql.upper() + assert "LAG" in upper + assert "SUM" in upper + + def test_transform_should_preserve_alias_when_cluster_has_alias(self): + """Test CLUSTER alias is preserved on the SUM window expression. + + Given: + A parsed SELECT with CLUSTER(interval) AS cluster_id + When: + transform is called + Then: + It should preserve the alias on the SUM window expression + """ + # Act + sql = _transpile_with_transformer( + "SELECT *, CLUSTER(interval) AS cluster_id FROM features", + ClusterTransformer, + ) + + # Assert + assert "cluster_id" in sql + + def test_transform_should_include_distance_when_cluster_has_distance(self): + """Test CLUSTER with distance adds the distance to the LAG result. + + Given: + A parsed SELECT with CLUSTER(interval, 1000) + When: + transform is called + Then: + It should add distance 1000 to the LAG result + """ + # Act + sql = _transpile_with_transformer( + "SELECT *, CLUSTER(interval, 1000) FROM features", + ClusterTransformer, + ) + + # Assert + upper = sql.upper() + assert "LAG" in upper + assert "1000" in sql + + def test_transform_should_partition_by_strand_when_stranded(self): + """Test stranded CLUSTER partitions by chrom AND strand. + + Given: + A parsed SELECT with CLUSTER(interval, stranded := true) + When: + transform is called + Then: + It should partition by chrom AND strand + """ + # Act + sql = _transpile_with_transformer( + "SELECT *, CLUSTER(interval, stranded := true) FROM features", + ClusterTransformer, + ) + + # Assert + upper = sql.upper() + assert "STRAND" in upper + # Both chrom and strand should appear in partition + assert "CHROM" in upper + + def test_transform_should_return_unchanged_when_expression_is_not_select(self): + """Test non-SELECT expression passes through unchanged. + + Given: + A non-SELECT expression + When: + transform is called + Then: + It should return the expression unchanged + """ + # Arrange + tables = _make_tables("features") + transformer = ClusterTransformer(tables) + insert = exp.Insert(this=exp.to_table("features")) + + # Act + result = transformer.transform(insert) + + # Assert + assert result is insert + + def test_transform_should_return_unchanged_when_no_cluster(self): + """Test SELECT without CLUSTER passes through unchanged. + + Given: + A SELECT with no CLUSTER expressions + When: + transform is called + Then: + It should return the query unchanged + """ + # Arrange + tables = _make_tables("features") + transformer = ClusterTransformer(tables) + ast = parse_one("SELECT * FROM features", dialect=GIQLDialect) + + # Act + result = transformer.transform(ast) + + # Assert + assert result is ast + + def test_transform_should_use_custom_column_names_when_tables_configured(self): + """Test custom column names from Tables propagate into output SQL. + + Given: + A Tables instance with custom column names + When: + transform is called on a CLUSTER query + Then: + The generated query should use the custom column names + """ + # Arrange + custom = Table( + "features", + chrom_col="chromosome", + start_col="start_pos", + end_col="end_pos", + ) + tables = _make_tables(features=custom) + + # Act + sql = _transpile_with_transformer( + "SELECT *, CLUSTER(interval) FROM features", + ClusterTransformer, + tables=tables, + ) + + # Assert + assert "chromosome" in sql + assert "start_pos" in sql + assert "end_pos" in sql + + def test_transform_should_recurse_when_cluster_inside_cte(self): + """Test CLUSTER inside a CTE subquery is recursively transformed. + + Given: + A SELECT with CLUSTER inside a CTE subquery + When: + transform is called + Then: + It should recursively transform the CTE subquery + """ + # Act + sql = _transpile_with_transformer( + "WITH c AS (SELECT *, CLUSTER(interval) AS cid FROM features) " + "SELECT * FROM c", + ClusterTransformer, + ) + + # Assert + upper = sql.upper() + assert "LAG" in upper + assert "SUM" in upper + + def test_transform_should_preserve_where_when_cluster_has_where(self): + """Test WHERE clause is preserved alongside CLUSTER. + + Given: + A SELECT with CLUSTER and a WHERE clause + When: + transform is called + Then: + It should preserve the WHERE clause + """ + # Act + sql = _transpile_with_transformer( + "SELECT *, CLUSTER(interval) FROM features WHERE score > 10", + ClusterTransformer, + ) + + # Assert + assert "score > 10" in sql + + def test_transform_should_add_required_genomic_columns_when_specific_columns(self): + """Test specific column selection adds required genomic cols to CTE. + + Given: + A SELECT with specific columns (not *) and CLUSTER + When: + transform is called + Then: + It should add missing required genomic columns to the CTE select list + """ + # Act + sql = _transpile_with_transformer( + "SELECT name, CLUSTER(interval) AS cid FROM features", + ClusterTransformer, + ) + + # Assert + upper = sql.upper() + # Required genomic cols should be in the output + assert "CHROM" in upper + assert "START" in upper + assert "END" in upper + + +# =========================================================================== +# TestMergeTransformer +# =========================================================================== + + +class TestMergeTransformer: + """Tests for MergeTransformer.transform.""" + + def test_transform_should_produce_group_by_min_max_when_basic_merge(self): + """Test basic MERGE produces GROUP BY with MIN(start) and MAX(end). + + Given: + A Tables instance and a parsed SELECT with MERGE(interval) + When: + transform is called + Then: + It should produce a result with GROUP BY, MIN(start), MAX(end) + """ + # Act + sql = _transpile_with_transformer( + "SELECT MERGE(interval) FROM features", MergeTransformer + ) + + # Assert + upper = sql.upper() + assert "GROUP BY" in upper + assert "MIN(" in upper + assert "MAX(" in upper + + def test_transform_should_produce_fixed_columns_when_merge_has_alias(self): + """Test MERGE alias is dropped but output still has fixed columns. + + Given: + A parsed SELECT with MERGE(interval) AS merged + When: + transform is called + Then: + It should still produce valid output with fixed columns + """ + # Act + sql = _transpile_with_transformer( + "SELECT MERGE(interval) AS merged FROM features", + MergeTransformer, + ) + + # Assert + upper = sql.upper() + assert "GROUP BY" in upper + assert "MIN(" in upper + assert "MAX(" in upper + + def test_transform_should_pass_distance_when_merge_has_distance(self): + """Test MERGE with distance passes the distance through to CLUSTER. + + Given: + A parsed SELECT with MERGE(interval, 1000) + When: + transform is called + Then: + It should pass the distance through to CLUSTER + """ + # Act + sql = _transpile_with_transformer( + "SELECT MERGE(interval, 1000) FROM features", + MergeTransformer, + ) + + # Assert + assert "1000" in sql + + def test_transform_should_add_strand_to_group_by_when_stranded(self): + """Test stranded MERGE adds strand to GROUP BY and partition. + + Given: + A parsed SELECT with MERGE(interval, stranded := true) + When: + transform is called + Then: + strand should appear in GROUP BY and partition + """ + # Act + sql = _transpile_with_transformer( + "SELECT MERGE(interval, stranded := true) FROM features", + MergeTransformer, + ) + + # Assert + upper = sql.upper() + assert "STRAND" in upper + assert "GROUP BY" in upper + + def test_transform_should_return_unchanged_when_expression_is_not_select(self): + """Test non-SELECT expression passes through unchanged. + + Given: + A non-SELECT expression + When: + transform is called + Then: + It should return the expression unchanged + """ + # Arrange + tables = _make_tables("features") + transformer = MergeTransformer(tables) + insert = exp.Insert(this=exp.to_table("features")) + + # Act + result = transformer.transform(insert) + + # Assert + assert result is insert + + def test_transform_should_return_unchanged_when_no_merge(self): + """Test SELECT without MERGE passes through unchanged. + + Given: + A SELECT with no MERGE expressions + When: + transform is called + Then: + It should return the query unchanged + """ + # Arrange + tables = _make_tables("features") + transformer = MergeTransformer(tables) + ast = parse_one("SELECT * FROM features", dialect=GIQLDialect) + + # Act + result = transformer.transform(ast) + + # Assert + assert result is ast + + def test_transform_should_raise_when_multiple_merge_expressions(self): + """Test two MERGE expressions raise ValueError. + + Given: + A SELECT with two MERGE expressions + When: + transform is called + Then: + It should raise ValueError + """ + # Arrange + tables = _make_tables("features") + transformer = MergeTransformer(tables) + ast = parse_one( + "SELECT MERGE(interval), MERGE(interval) FROM features", + dialect=GIQLDialect, + ) + + # Act & Assert + with pytest.raises(ValueError, match="Multiple MERGE"): + transformer.transform(ast) + + def test_transform_should_preserve_where_when_merge_has_where(self): + """Test WHERE clause is preserved in the clustered subquery. + + Given: + A SELECT with MERGE and a WHERE clause + When: + transform is called + Then: + It should preserve the WHERE clause in the clustered subquery + """ + # Act + sql = _transpile_with_transformer( + "SELECT MERGE(interval) FROM features WHERE score > 10", + MergeTransformer, + ) + + # Assert + assert "score > 10" in sql + + def test_transform_should_recurse_when_merge_inside_cte(self): + """Test MERGE inside a CTE subquery is recursively transformed. + + Given: + A SELECT with MERGE inside a CTE subquery + When: + transform is called + Then: + It should recursively transform the CTE subquery + """ + # Act + sql = _transpile_with_transformer( + "WITH m AS (SELECT MERGE(interval) FROM features) SELECT * FROM m", + MergeTransformer, + ) + + # Assert + upper = sql.upper() + assert "GROUP BY" in upper + assert "MIN(" in upper + assert "MAX(" in upper + + +# =========================================================================== +# TestRasterizeTransformer +# =========================================================================== + + +class TestRasterizeTransformer: + """Tests for RasterizeTransformer.transform via transpile().""" + + # ------------------------------------------------------------------ + # Instantiation + # ------------------------------------------------------------------ + + def test___init___should_store_tables_reference(self): + """Test RasterizeTransformer stores its tables reference. + + Given: + A Tables container with registered tables + When: + RasterizeTransformer is instantiated + Then: + It should store the tables reference + """ + # Arrange + tables = Tables() + tables.register("features", Table("features")) + + # Act + transformer = RasterizeTransformer(tables) + + # Assert + assert transformer.tables is tables + + # ------------------------------------------------------------------ + # Basic transpilation + # ------------------------------------------------------------------ + + def test_transform_should_produce_expected_sql_structure_when_basic_count(self): + """Test basic RASTERIZE produces correct SQL structure. + + Given: + A basic RASTERIZE query with count (default stat) + When: + Transpiled + Then: + It should produce SQL with __giql_bins CTE, GENERATE_SERIES, + LEFT JOIN, GROUP BY, COUNT, and ORDER BY + """ + # Act + sql = transpile( + "SELECT RASTERIZE(interval, 1000) FROM features", + tables=["features"], + ) + + # Assert + upper = sql.upper() + assert "__GIQL_BINS" in upper + assert "GENERATE_SERIES" in upper + assert "LEFT JOIN" in upper + assert "GROUP BY" in upper + assert "COUNT" in upper + assert "ORDER BY" in upper + + def test_transform_should_return_unchanged_when_no_rasterize_expression(self): + """Test non-RASTERIZE query passes through unchanged. + + Given: + A query with no RASTERIZE expression + When: + Transformed by RasterizeTransformer + Then: + It should return the query unchanged + """ + # Arrange + tables = Tables() + tables.register("features", Table("features")) + transformer = RasterizeTransformer(tables) + ast = parse_one("SELECT * FROM features", dialect=GIQLDialect) + + # Act + result = transformer.transform(ast) + + # Assert + assert result is ast + + # ------------------------------------------------------------------ + # Default alias + # ------------------------------------------------------------------ + + def test_transform_should_use_value_alias_when_no_explicit_alias(self): + """Test bare RASTERIZE gets default 'value' alias. + + Given: + A RASTERIZE query without an explicit AS alias + When: + Transpiled + Then: + It should alias the aggregate as "value" + """ + # Act + sql = transpile( + "SELECT RASTERIZE(interval, 1000) FROM features", + tables=["features"], + ) + + # Assert + assert "AS value" in sql + + def test_transform_should_use_explicit_alias_when_alias_provided(self): + """Test explicit AS alias overrides default. + + Given: + A RASTERIZE query with explicit AS alias + When: + Transpiled + Then: + It should use the explicit alias, not "value" + """ + # Act + sql = transpile( + "SELECT RASTERIZE(interval, 1000) AS depth FROM features", + tables=["features"], + ) + + # Assert + assert "AS depth" in sql + assert "AS value" not in sql + + # ------------------------------------------------------------------ + # WHERE clause semantics + # ------------------------------------------------------------------ + + def test_transform_should_move_where_to_join_on_when_where_present(self): + """Test WHERE migrates into LEFT JOIN ON clause. + + Given: + A RASTERIZE query with a WHERE clause + When: + Transpiled + Then: + It should move the WHERE condition into the LEFT JOIN ON clause, + not the outer WHERE + """ + # Act + sql = transpile( + "SELECT RASTERIZE(interval, 1000) FROM features WHERE score > 10", + tables=["features"], + ) + + # Assert + upper = sql.upper() + assert "ON" in upper + assert "SCORE > 10" in upper + # The condition should be in the ON clause (between LEFT JOIN and GROUP BY) + after_join = sql.split("LEFT JOIN")[1] + on_clause = after_join.split("GROUP BY")[0] + assert "score > 10" in on_clause + + def test_transform_should_qualify_columns_in_on_when_where_present(self): + """Test WHERE column references are qualified with source table in ON. + + Given: + A RASTERIZE query with a WHERE clause + When: + Transpiled + Then: + It should qualify unqualified column references in the JOIN ON + with the source table + """ + # Act + sql = transpile( + "SELECT RASTERIZE(interval, 1000) FROM features WHERE score > 10", + tables=["features"], + ) + + # Assert + after_join = sql.split("LEFT JOIN")[1] + on_clause = after_join.split("GROUP BY")[0] + assert "features.score" in on_clause + + def test_transform_should_apply_where_to_chroms_subquery_when_where_present(self): + """Test WHERE is also applied to the chroms subquery. + + Given: + A RASTERIZE query with a WHERE clause + When: + Transpiled + Then: + It should also apply the WHERE to the chroms subquery with + table-qualified columns + """ + # Act + sql = transpile( + "SELECT RASTERIZE(interval, 1000) FROM features WHERE score > 10", + tables=["features"], + ) + + # Assert + # The chroms subquery is inside the CTE, before the outer SELECT + cte_part = sql.split(") SELECT")[0] + assert "features.score > 10" in cte_part + + # ------------------------------------------------------------------ + # Column mapping + # ------------------------------------------------------------------ + + def test_transform_should_use_custom_column_names_when_mapping_provided(self): + """Test custom column names are used throughout. + + Given: + A RASTERIZE query with custom column mappings + (chromosome, start_pos, end_pos) + When: + Transpiled + Then: + It should use the mapped column names throughout + """ + # Act + sql = transpile( + "SELECT RASTERIZE(interval, 1000) FROM peaks", + tables=[ + Table( + "peaks", + genomic_col="interval", + chrom_col="chromosome", + start_col="start_pos", + end_col="end_pos", + ) + ], + ) + + # Assert + assert "chromosome" in sql + assert "start_pos" in sql + assert "end_pos" in sql + + # ------------------------------------------------------------------ + # Additional SELECT columns + # ------------------------------------------------------------------ + + def test_transform_should_include_extra_columns_when_additional_select_columns(self): + """Test extra SELECT columns pass through alongside RASTERIZE. + + Given: + A RASTERIZE query with additional columns alongside RASTERIZE + When: + Transpiled + Then: + It should include the extra columns in the output + """ + # Act + sql = transpile( + "SELECT RASTERIZE(interval, 500) AS cov, name FROM features", + tables=["features"], + ) + + # Assert + upper = sql.upper() + assert "COV" in upper + assert "NAME" in upper + assert "COUNT" in upper + + # ------------------------------------------------------------------ + # Table alias + # ------------------------------------------------------------------ + + def test_transform_should_use_alias_as_source_when_table_has_alias(self): + """Test table alias is used as source reference in JOIN. + + Given: + A RASTERIZE query with a table alias (FROM features f) + When: + Transpiled + Then: + It should use the alias as the source reference in JOIN + """ + # Act + sql = transpile( + "SELECT RASTERIZE(interval, 1000) FROM features f", + tables=["features"], + ) + + # Assert + upper = sql.upper() + assert "GENERATE_SERIES" in upper + assert "LEFT JOIN" in upper + + # ------------------------------------------------------------------ + # Resolution + # ------------------------------------------------------------------ + + def test_transform_should_propagate_resolution_when_resolution_provided(self): + """Test resolution value propagates to generate_series and bin width. + + Given: + A RASTERIZE query with resolution=500 + When: + Transpiled + Then: + It should use 500 as the step in generate_series and bin width + """ + # Act + sql = transpile( + "SELECT RASTERIZE(interval, 500) FROM features", + tables=["features"], + ) + + # Assert + assert "500" in sql + + # ------------------------------------------------------------------ + # CTE nesting + # ------------------------------------------------------------------ + + def test_transform_should_transform_rasterize_when_rasterize_inside_cte(self): + """Test RASTERIZE inside a WITH clause is transformed correctly. + + Given: + A RASTERIZE expression inside a WITH clause + When: + Transpiled + Then: + It should correctly transform the CTE containing RASTERIZE + """ + # Act + sql = transpile( + "WITH cov AS (SELECT RASTERIZE(interval, 1000) FROM features) " + "SELECT * FROM cov", + tables=["features"], + ) + + # Assert + upper = sql.upper() + assert "GENERATE_SERIES" in upper + assert "LEFT JOIN" in upper + assert "COUNT" in upper + + # ------------------------------------------------------------------ + # Error handling + # ------------------------------------------------------------------ + + def test_transform_should_raise_when_multiple_rasterize_expressions(self): + """Test multiple RASTERIZE expressions raise error. + + Given: + A query with two RASTERIZE expressions + When: + Transpiled + Then: + It should raise ValueError matching "Multiple RASTERIZE" + """ + # Act & Assert + with pytest.raises(ValueError, match="Multiple RASTERIZE"): + transpile( + "SELECT RASTERIZE(interval, 1000), RASTERIZE(interval, 500) FROM features", + tables=["features"], + ) + + def test_transform_should_raise_when_from_is_subquery(self): + """Test subquery in FROM raises a descriptive error. + + Given: + A RASTERIZE query whose FROM clause is an inline subquery + When: + Transpiled + Then: + It should raise ValueError matching "FROM clause" + """ + # Act & Assert + with pytest.raises(ValueError, match="FROM clause"): + transpile( + "SELECT RASTERIZE(interval, 1000) " + "FROM (SELECT * FROM features) AS sub", + tables=["features"], + ) + + def test_transform_should_raise_when_resolution_is_negative(self): + """Test negative resolution raises descriptive error. + + Given: + A RASTERIZE query with resolution = -1 + When: + Transpiled + Then: + It should raise ValueError matching "positive" + """ + # Act & Assert + with pytest.raises(ValueError, match="positive"): + transpile( + "SELECT RASTERIZE(interval, -1) FROM features", + tables=["features"], + ) + + def test_transform_should_raise_when_resolution_is_zero(self): + """Test zero resolution raises descriptive error. + + Given: + A RASTERIZE query with resolution = 0 + When: + Transpiled + Then: + It should raise ValueError matching "positive" + """ + # Act & Assert + with pytest.raises(ValueError, match="positive"): + transpile( + "SELECT RASTERIZE(interval, 0) FROM features", + tables=["features"], + ) + + # ------------------------------------------------------------------ + # Functional / DuckDB end-to-end + # ------------------------------------------------------------------ + + def test_transform_should_produce_bins_when_basic_count(self, to_df): + """Test count correctness with two intervals in one bin. + + Given: + A DuckDB table with two intervals in the same 1000bp bin + When: + RASTERIZE count is transpiled and executed + Then: + It should return exactly one bin with count=2 + """ + # Arrange + giql_sql = transpile( + "SELECT RASTERIZE(interval, 1000) FROM features", + tables=["features"], + ) + conn = duckdb.connect(":memory:") + conn.execute( + "CREATE TABLE features AS " + "SELECT 'chr1' AS chrom, 100 AS start, 200 AS \"end\" " + "UNION ALL SELECT 'chr1', 300, 400" + ) + + # Act + df = to_df(conn.execute(giql_sql)) + conn.close() + + # Assert + assert len(df) == 1 + assert set(df["start"].tolist()) == {0} + row = df[df["start"] == 0].iloc[0] + assert row["value"] == 2 + + def test_transform_should_produce_zero_coverage_bins_when_gaps_exist(self, to_df): + """Test zero-coverage bins are present via LEFT JOIN. + + Given: + A DuckDB table with intervals in bins [0,1000) and [2000,3000) + but none in bin [1000,2000), and RASTERIZE resolution=1000 + When: + RASTERIZE count is transpiled and executed + Then: + All three bins should be returned and the middle bin should + report value=0 + """ + # Arrange + giql_sql = transpile( + "SELECT RASTERIZE(interval, 1000) FROM features", + tables=["features"], + ) + conn = duckdb.connect(":memory:") + conn.execute( + "CREATE TABLE features AS " + "SELECT 'chr1' AS chrom, 100 AS start, 200 AS \"end\" " + "UNION ALL SELECT 'chr1', 2500, 2600" + ) + + # Act + df = to_df(conn.execute(giql_sql)) + conn.close() + + # Assert + assert len(df) == 3 + assert set(df["start"].tolist()) == {0, 1000, 2000} + assert df[df["start"] == 0].iloc[0]["value"] == 1 + assert df[df["start"] == 1000].iloc[0]["value"] == 0 + assert df[df["start"] == 2000].iloc[0]["value"] == 1 + + def test_transform_should_omit_trailing_bin_when_end_on_boundary(self, to_df): + """Test no spurious trailing bin when MAX(end) is on a bin boundary. + + Given: + An interval at chr1:100-1000 with resolution=1000 — MAX(end) + lands exactly on a bin boundary + When: + RASTERIZE is transpiled and executed + Then: + Exactly one bin [0,1000) should be returned with value=1 + """ + # Arrange + giql_sql = transpile( + "SELECT RASTERIZE(interval, 1000) FROM features", + tables=["features"], + ) + conn = duckdb.connect(":memory:") + conn.execute( + "CREATE TABLE features AS " + "SELECT 'chr1' AS chrom, 100 AS start, 1000 AS \"end\"" + ) + + # Act + df = to_df(conn.execute(giql_sql)) + conn.close() + + # Assert + assert len(df) == 1 + assert df.iloc[0]["start"] == 0 + assert df.iloc[0]["value"] == 1 + + def test_transform_should_return_zero_when_bin_has_no_matching_rows(self, to_df): + """Test bins with no matching source rows return value=0. + + Given: + A DuckDB table with intervals at chr1:100-200 and chr1:2500-2600 + and RASTERIZE resolution=500 (bins [0,500), [500,1000), ..., + [2500,3000)) + When: + RASTERIZE count is transpiled and executed + Then: + Bins [500,1000), [1000,1500), [1500,2000), [2000,2500) should + all report value=0 + """ + # Arrange + giql_sql = transpile( + "SELECT RASTERIZE(interval, 500) FROM features", + tables=["features"], + ) + conn = duckdb.connect(":memory:") + conn.execute( + "CREATE TABLE features AS " + "SELECT 'chr1' AS chrom, 100 AS start, 200 AS \"end\" " + "UNION ALL SELECT 'chr1', 2500, 2600" + ) + + # Act + df = to_df(conn.execute(giql_sql)) + conn.close() + + # Assert + empty_bin_starts = {500, 1000, 1500, 2000} + for bin_start in empty_bin_starts: + value = df[df["start"] == bin_start].iloc[0]["value"] + assert value == 0, ( + f"bin [{bin_start},{bin_start + 500}) expected 0, got {value}" + ) + + def test_transform_should_preserve_user_ctes_when_rasterize_wraps_them(self, to_df): + """Test user-defined CTEs are preserved when RASTERIZE wraps them. + + Given: + A query with a user-defined CTE (selected) that pre-filters + the source, followed by SELECT RASTERIZE(...) FROM selected + When: + RASTERIZE is transpiled and executed + Then: + The user CTE should be preserved alongside __giql_bins and + the query should execute without "table not found" errors + """ + # Arrange + giql_sql = transpile( + "WITH selected AS (SELECT chrom, start, \"end\" FROM features WHERE score > 50) " + "SELECT RASTERIZE(interval, 1000) FROM selected", + tables=["features", "selected"], + ) + conn = duckdb.connect(":memory:") + conn.execute( + "CREATE TABLE features AS " + "SELECT 'chr1' AS chrom, 100 AS start, 200 AS \"end\", 80 AS score " + "UNION ALL SELECT 'chr1', 1100, 1200, 10 " + "UNION ALL SELECT 'chr1', 2100, 2200, 90" + ) + + # Act + df = to_df(conn.execute(giql_sql)) + conn.close() + + # Assert + assert set(df["start"].tolist()) == {0, 1000, 2000} + assert df[df["start"] == 1000].iloc[0]["value"] == 0 + + def test_transform_should_resolve_alias_when_where_uses_table_alias(self, to_df): + """Test alias-qualified WHERE resolves in chroms subquery. + + Given: + A FROM clause with a table alias (features f) and a WHERE + qualifying a column by that alias (f.score > 10) + When: + RASTERIZE is transpiled and executed + Then: + The query should run without binder errors and produce all + three bins with WHERE-filtering applied + """ + # Arrange + giql_sql = transpile( + "SELECT RASTERIZE(interval, 1000) FROM features f WHERE f.score > 10", + tables=["features"], + ) + conn = duckdb.connect(":memory:") + conn.execute( + "CREATE TABLE features AS " + "SELECT 'chr1' AS chrom, 100 AS start, 200 AS \"end\", 50 AS score " + "UNION ALL SELECT 'chr1', 1100, 1200, 5 " + "UNION ALL SELECT 'chr1', 2100, 2200, 80" + ) + + # Act + df = to_df(conn.execute(giql_sql)) + conn.close() + + # Assert + assert len(df) == 3 + assert set(df["start"].tolist()) == {0, 1000, 2000} + + def test_transform_should_preserve_zero_bins_when_where_in_on(self, to_df): + """Test WHERE in ON preserves bins without matching intervals. + + Given: + A DuckDB table with high-scoring intervals in bin [0,1000) and + bin [2000,3000), plus a low-scoring interval in bin [1000,2000) + When: + RASTERIZE count with WHERE score > 50 is transpiled and executed + Then: + All three bins should be present (the WHERE is in the ON clause + so bins are not dropped even when no source rows match) + """ + # Arrange + giql_sql = transpile( + "SELECT RASTERIZE(interval, 1000) FROM features WHERE score > 50", + tables=["features"], + ) + conn = duckdb.connect(":memory:") + conn.execute( + "CREATE TABLE features AS " + "SELECT 'chr1' AS chrom, 100 AS start, 200 AS \"end\", 100 AS score " + "UNION ALL SELECT 'chr1', 1500, 1600, 10 " + "UNION ALL SELECT 'chr1', 2100, 2200, 80" + ) + + # Act + df = to_df(conn.execute(giql_sql)) + conn.close() + + # Assert — all three bins are present (not filtered by WHERE) + assert len(df) == 3 + assert set(df["start"].tolist()) == {0, 1000, 2000} + + def test_transform_should_count_interval_in_each_overlapped_bin_when_interval_spans_bins( + self, to_df + ): + """Test bedtools-coverage convention: an interval is counted in every bin it overlaps. + + Given: + A DuckDB table with one interval [500, 2500) that spans the + three adjacent 1000bp bins [0, 1000), [1000, 2000), [2000, 3000) + When: + RASTERIZE count is transpiled and executed + Then: + The interval should be counted once in each of the three bins, + matching `bedtools coverage` semantics — totals do not conserve + """ + # Arrange + giql_sql = transpile( + "SELECT RASTERIZE(interval, 1000) FROM features", + tables=["features"], + ) + conn = duckdb.connect(":memory:") + conn.execute( + "CREATE TABLE features AS " + "SELECT 'chr1' AS chrom, 500 AS start, 2500 AS \"end\"" + ) + + # Act + df = to_df(conn.execute(giql_sql)).sort_values("start").reset_index(drop=True) + conn.close() + + # Assert + assert df["start"].tolist() == [0, 1000, 2000] + assert df["value"].tolist() == [1, 1, 1] + + # ------------------------------------------------------------------ + # Property-based transpilation + # ------------------------------------------------------------------ + + @given(resolution=st.integers(min_value=1, max_value=10_000_000)) + @settings(max_examples=50, suppress_health_check=[HealthCheck.function_scoped_fixture]) + def test_transform_should_contain_structural_elements_when_varying_resolution( + self, resolution + ): + """Test transpiled SQL always contains required structural elements. + + Given: + Any valid resolution (1-10M) + When: + Transpiled via transpile() + Then: + The output SQL should always contain __GIQL_BINS, + GENERATE_SERIES, LEFT JOIN, GROUP BY, COUNT, ORDER BY, + and the resolution value as the bin step + """ + # Act + sql = transpile( + f"SELECT RASTERIZE(interval, {resolution}) FROM features", + tables=["features"], + ) + + # Assert + upper = sql.upper() + assert "__GIQL_BINS" in upper + assert "GENERATE_SERIES" in upper + assert "LEFT JOIN" in upper + assert "GROUP BY" in upper + assert "COUNT" in upper + assert "ORDER BY" in upper + assert str(resolution) in sql diff --git a/tests/unit/test_transpile.py b/tests/unit/test_transpile.py new file mode 100644 index 0000000..5600cdf --- /dev/null +++ b/tests/unit/test_transpile.py @@ -0,0 +1,425 @@ +"""Unit tests for the transpile() function. + +Tests covering all public API behavior of giql.transpile as a black box: +GIQL string in, SQL string out. +""" + +import pytest + +from giql import Table +from giql import transpile + + +class TestTranspile: + """Tests for transpile() public API.""" + + # ── Basic transpilation ────────────────────────────────────────── + + def test_transpile_should_passthrough_plain_sql_unchanged(self): + """Test that plain SQL without GIQL extensions passes through. + + Given: + A plain SQL query with no GIQL extensions + When: + transpile is called + Then: + It should return an equivalent SQL string unchanged + """ + # Arrange / Act + sql = transpile("SELECT id, name FROM features") + + # Assert + upper = sql.upper() + assert "SELECT" in upper + assert "FEATURES" in upper + assert "ID" in upper + + def test_transpile_should_emit_correct_sql_for_intersects_predicate(self): + """Test INTERSECTS predicate expands to range comparisons. + + Given: + A query with an INTERSECTS predicate and a tables list + When: + transpile is called + Then: + It should return SQL that contains expanded range comparison predicates + """ + # Arrange / Act + sql = transpile( + "SELECT * FROM features WHERE interval INTERSECTS 'chr1:1000-2000'", + tables=["features"], + ) + + # Assert + upper = sql.upper() + assert "CHR1" in upper + assert "1000" in sql + assert "2000" in sql + # Range overlap requires both start/end comparisons + assert "START" in upper or "END" in upper + + def test_transpile_should_emit_correct_sql_for_contains_predicate(self): + """Test CONTAINS predicate produces containment SQL. + + Given: + A query with a CONTAINS predicate + When: + transpile is called + Then: + It should return SQL that contains containment predicates + """ + # Arrange / Act + sql = transpile( + "SELECT * FROM features WHERE interval CONTAINS 'chr1:1500'", + tables=["features"], + ) + + # Assert + upper = sql.upper() + assert "SELECT" in upper + assert "1500" in sql + + def test_transpile_should_emit_correct_sql_for_within_predicate(self): + """Test WITHIN predicate produces within SQL. + + Given: + A query with a WITHIN predicate + When: + transpile is called + Then: + It should return SQL that contains within predicates + """ + # Arrange / Act + sql = transpile( + "SELECT * FROM features WHERE interval WITHIN 'chr1:1000-2000'", + tables=["features"], + ) + + # Assert + upper = sql.upper() + assert "SELECT" in upper + assert "1000" in sql + assert "2000" in sql + + # ── CLUSTER transpilation ──────────────────────────────────────── + + def test_transpile_should_emit_window_functions_for_cluster(self): + """Test CLUSTER expands to LAG and SUM window functions. + + Given: + A query with CLUSTER(interval) and tables=["features"] + When: + transpile is called + Then: + It should return SQL that contains LAG and SUM window functions in a subquery + """ + # Arrange / Act + sql = transpile( + "SELECT *, CLUSTER(interval) AS cluster_id FROM features", + tables=["features"], + ) + + # Assert + upper = sql.upper() + assert "LAG" in upper + assert "SUM" in upper + + def test_transpile_should_include_distance_offset_for_cluster_with_distance(self): + """Test CLUSTER with distance includes the offset in LAG. + + Given: + A query with CLUSTER(interval, 1000) + When: + transpile is called + Then: + It should return SQL that includes a distance offset in the LAG expression + """ + # Arrange / Act + sql = transpile( + "SELECT *, CLUSTER(interval, 1000) AS cluster_id FROM features", + tables=["features"], + ) + + # Assert + upper = sql.upper() + assert "LAG" in upper + assert "1000" in sql + + # ── MERGE transpilation ────────────────────────────────────────── + + def test_transpile_should_emit_group_by_aggregation_for_merge(self): + """Test MERGE expands to CTE with GROUP BY and MIN/MAX. + + Given: + A query with MERGE(interval) and tables=["features"] + When: + transpile is called + Then: + It should return SQL that contains a CLUSTER CTE with GROUP BY and MIN/MAX aggregation + """ + # Arrange / Act + sql = transpile( + "SELECT MERGE(interval) FROM features", + tables=["features"], + ) + + # Assert + upper = sql.upper() + assert "MIN" in upper + assert "MAX" in upper + assert "GROUP BY" in upper + + # ── RASTERIZE transpilation ─────────────────────────────────────── + + def test_transpile_should_emit_bins_cte_for_rasterize(self): + """Test RASTERIZE expands to bins CTE with LEFT JOIN and COUNT. + + Given: + A query with RASTERIZE(interval, 1000) and tables=["features"] + When: + transpile is called + Then: + It should return SQL that contains a bins CTE, LEFT JOIN, COUNT, GROUP BY, and ORDER BY + """ + # Arrange / Act + sql = transpile( + "SELECT RASTERIZE(interval, 1000) FROM features", + tables=["features"], + ) + + # Assert + upper = sql.upper() + assert "LEFT JOIN" in upper or "LEFT OUTER JOIN" in upper + assert "COUNT" in upper + assert "GROUP BY" in upper + assert "ORDER BY" in upper + assert "1000" in sql + + def test_transpile_should_use_custom_alias_for_rasterize_when_provided(self): + """Test RASTERIZE with AS cov aliases the aggregate column as "cov". + + Given: + A query with RASTERIZE(interval, 1000) AS cov + When: + transpile is called + Then: + It should alias the aggregate column in the returned SQL as "cov" + """ + # Arrange / Act + sql = transpile( + "SELECT RASTERIZE(interval, 1000) AS cov FROM features", + tables=["features"], + ) + + # Assert + assert "cov" in sql.lower() + + def test_transpile_should_use_default_value_alias_for_bare_rasterize(self): + """Test bare RASTERIZE aliases the aggregate column as "value". + + Given: + A query with bare RASTERIZE(interval, 1000) (no alias) + When: + transpile is called + Then: + It should alias the aggregate column in the returned SQL as "value" + """ + # Arrange / Act + sql = transpile( + "SELECT RASTERIZE(interval, 1000) FROM features", + tables=["features"], + ) + + # Assert + assert "value" in sql.lower() + + def test_transpile_should_fold_where_into_join_on_for_rasterize(self): + """Test RASTERIZE folds WHERE into the JOIN ON condition. + + Given: + A query with RASTERIZE and a WHERE clause + When: + transpile is called + Then: + It should place the WHERE condition in the JOIN ON condition rather than as a standalone WHERE + """ + # Arrange / Act + sql = transpile( + "SELECT RASTERIZE(interval, 1000) FROM features WHERE chrom = 'chr1'", + tables=["features"], + ) + + # Assert + upper = sql.upper() + # The WHERE should be folded into the JOIN ON condition + assert "JOIN" in upper + assert "CHR1" in upper + + # ── DISTANCE transpilation ─────────────────────────────────────── + + def test_transpile_should_emit_case_expression_for_distance(self): + """Test DISTANCE expands to a CASE expression. + + Given: + A query with DISTANCE(a.interval, b.interval) and two tables + When: + transpile is called + Then: + It should return SQL that contains a CASE expression for computing distance + """ + # Arrange / Act + sql = transpile( + "SELECT DISTANCE(a.interval, b.interval) FROM features a, genes b", + tables=["features", "genes"], + ) + + # Assert + upper = sql.upper() + assert "CASE" in upper + + # ── NEAREST transpilation ──────────────────────────────────────── + + def test_transpile_should_emit_lateral_subquery_with_limit_for_nearest(self): + """Test NEAREST expands to a LATERAL subquery with a LIMIT. + + Given: + A query with NEAREST in a LATERAL join and two tables + When: + transpile is called + Then: + It should return SQL that contains a LATERAL subquery with a LIMIT clause + """ + # Arrange / Act + sql = transpile( + """ + SELECT * + FROM peaks + CROSS JOIN LATERAL NEAREST(genes, reference=peaks.interval, k=3) + """, + tables=["peaks", "genes"], + ) + + # Assert + upper = sql.upper() + assert "LATERAL" in upper + assert "LIMIT" in upper + + # ── Table configuration ────────────────────────────────────────── + + def test_transpile_should_register_string_tables_with_default_columns(self): + """Test string-list tables use default column mappings. + + Given: + tables parameter as a list of strings + When: + transpile is called + Then: + It should register tables with default column mappings (chrom, start, end) + """ + # Arrange / Act + sql = transpile( + "SELECT * FROM features WHERE interval INTERSECTS 'chr1:100-200'", + tables=["features"], + ) + + # Assert + upper = sql.upper() + assert '"CHROM"' in upper or "CHROM" in upper + assert '"START"' in upper or "START" in upper + assert '"END"' in upper or "END" in upper + + def test_transpile_should_honor_custom_table_object_column_names(self): + """Test Table objects with custom column names propagate into SQL. + + Given: + tables parameter as a list of Table objects with custom column names + When: + transpile is called + Then: + It should generate SQL that uses those custom column names + """ + # Arrange / Act + sql = transpile( + "SELECT * FROM features WHERE interval INTERSECTS 'chr1:100-200'", + tables=[ + Table( + "features", + genomic_col="interval", + chrom_col="chromosome", + start_col="start_pos", + end_col="end_pos", + ) + ], + ) + + # Assert + assert "chromosome" in sql or "CHROMOSOME" in sql.upper() + assert "start_pos" in sql or "START_POS" in sql.upper() + assert "end_pos" in sql or "END_POS" in sql.upper() + + def test_transpile_should_use_default_columns_when_tables_is_none(self): + """Test None tables parameter still uses default column names. + + Given: + tables parameter is None + When: + transpile is called + Then: + It should still use default column names (chrom, start, end) + """ + # Arrange / Act + sql = transpile( + "SELECT * FROM features WHERE interval INTERSECTS 'chr1:100-200'", + tables=None, + ) + + # Assert + upper = sql.upper() + assert "SELECT" in upper + assert "CHROM" in upper + + def test_transpile_should_register_mixed_strings_and_table_objects(self): + """Test mixing strings and Table objects in tables parameter. + + Given: + tables parameter mixes strings and Table objects + When: + transpile is called + Then: + It should correctly register both and produce valid SQL + """ + # Arrange / Act + sql = transpile( + """ + SELECT a.*, b.* + FROM peaks a + JOIN genes b ON a.interval INTERSECTS b.region + """, + tables=[ + "peaks", + Table("genes", genomic_col="region", chrom_col="seqname"), + ], + ) + + # Assert + upper = sql.upper() + assert "PEAKS" in upper + assert "GENES" in upper + assert "SEQNAME" in upper + + # ── Error handling ─────────────────────────────────────────────── + + def test_transpile_should_raise_value_error_for_invalid_query(self): + """Test unparseable query raises ValueError with Parse error message. + + Given: + An invalid/unparseable query string + When: + transpile is called + Then: + It should raise ValueError with a message containing "Parse error" + """ + # Arrange / Act / Assert + with pytest.raises(ValueError, match="Parse error"): + transpile("SELECT * FORM features") +