-
Notifications
You must be signed in to change notification settings - Fork 76
[DataLoader] Add DataFusion SQLGlot dialect for SQL transpilation #501
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
Merged
Changes from all commits
Commits
Show all changes
18 commits
Select commit
Hold shift + click to select a range
b332c5c
[DataLoader] Add full DataFusion SQLGlot dialect for SQL transpilation
robreeves f796b3b
[DataLoader] Generalize SQL translator to accept any source dialect
robreeves 0c9222f
[DataLoader] Move translate_to_datafusion into datafusion_dialect module
robreeves a50df1b
[DataLoader] Rename module to datafusion_sql and function to to_dataf…
robreeves e12bb4d
[DataLoader] Handle datafusion as noop dialect and improve error mess…
robreeves 6106a1c
[DataLoader] Assert full output queries in dialect tests
robreeves b39308d
[DataLoader] Use to_datafusion_sql in tests and deduplicate
robreeves 91456b9
[DataLoader] Consolidate tests: merge type mappings, single e2e execu…
robreeves af32561
[DataLoader] Consolidate all transpilation tests into single parametr…
robreeves 88b0fc5
[DataLoader] Inline SUPPORTED_SOURCE_DIALECTS into error message
robreeves a4b131c
[DataLoader] Move identity round-trip into test_transpilation paramet…
robreeves f1a9ea5
[DataLoader] Remove SPARK constant, inline string literals in tests
robreeves 0305583
[DataLoader] Bump sqlglot minimum version to 29.0.0
robreeves 451aba8
[DataLoader] Wrap sqlglot transpile errors in ValueError
robreeves 7bf4888
[DataLoader] Fix median/percentile mappings, add approx_percentile_cont
robreeves ee7e995
[DataLoader] Add docstring to to_datafusion_sql
robreeves 0e55872
[DataLoader] Add UDF execution test for transpiler
robreeves 65d0367
[DataLoader] Fix line length lint error in test
robreeves File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
122 changes: 122 additions & 0 deletions
122
integrations/python/dataloader/src/openhouse/dataloader/datafusion_sql.py
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,122 @@ | ||
| from __future__ import annotations | ||
|
|
||
| import sqlglot | ||
| from sqlglot import exp | ||
| from sqlglot.dialects.dialect import Dialect, NormalizationStrategy, rename_func | ||
| from sqlglot.generator import Generator as _Generator | ||
| from sqlglot.parser import Parser as _Parser | ||
| from sqlglot.tokens import Tokenizer as _Tokenizer | ||
| from sqlglot.tokens import TokenType | ||
|
|
||
|
|
||
class DataFusion(Dialect):
    """sqlglot dialect for Apache DataFusion SQL.

    Configures tokenization, parsing, and generation so that
    ``sqlglot.transpile(..., write="datafusion")`` emits SQL that DataFusion
    can execute. Registering this class also makes the ``"datafusion"``
    dialect name resolvable via ``Dialect.classes``.
    """

    # Unquoted function names are normalized to lowercase, matching
    # DataFusion's lowercase built-in function names.
    NORMALIZE_FUNCTIONS: bool | str = "lower"
    # Unquoted identifiers are lowercased.
    NORMALIZATION_STRATEGY = NormalizationStrategy.LOWERCASE
    # NOTE(review): these flags mirror DataFusion semantics (NULLS LAST
    # default ordering, 0-based array indexing, integer division stays
    # integer) — confirm against the DataFusion version being targeted.
    NULL_ORDERING = "nulls_are_last"
    INDEX_OFFSET = 0
    TYPED_DIVISION = True
    SUPPORTS_USER_DEFINED_TYPES = False
    LOG_BASE_FIRST = True
    SUPPORTS_ORDER_BY_ALL = True

    class Tokenizer(_Tokenizer):
        # DataFusion quotes identifiers with double quotes (ANSI style).
        IDENTIFIERS = ['"']
        KEYWORDS = {
            **_Tokenizer.KEYWORDS,
            # DataFusion/Arrow alias for a string type.
            "UTF8": TokenType.TEXT,
        }

    class Parser(_Parser):
        # Map DataFusion-specific function names onto sqlglot's canonical
        # expression nodes so DataFusion SQL can also be read back.
        FUNCTIONS = {
            **_Parser.FUNCTIONS,
            "MAKE_ARRAY": exp.Array.from_arg_list,
            "CARDINALITY": exp.ArraySize.from_arg_list,
            "ARRAY_SORT": exp.SortArray.from_arg_list,
            "ARRAY_HAS": exp.ArrayContains.from_arg_list,
            "BOOL_AND": exp.LogicalAnd.from_arg_list,
            "BOOL_OR": exp.LogicalOr.from_arg_list,
            "BIT_AND": exp.BitwiseAndAgg.from_arg_list,
            "BIT_OR": exp.BitwiseOrAgg.from_arg_list,
            "BIT_XOR": exp.BitwiseXorAgg.from_arg_list,
            "VAR_POP": exp.VariancePop.from_arg_list,
            "VAR_SAMPLE": exp.Variance.from_arg_list,
            "STDDEV_POP": exp.StddevPop.from_arg_list,
            "COVAR_POP": exp.CovarPop.from_arg_list,
            "COVAR_SAMP": exp.CovarSamp.from_arg_list,
            "APPROX_DISTINCT": exp.ApproxDistinct.from_arg_list,
            "APPROX_PERCENTILE_CONT": exp.ApproxQuantile.from_arg_list,
            "STRING_AGG": exp.GroupConcat.from_arg_list,
            "NOW": exp.CurrentTimestamp.from_arg_list,
        }

    class Generator(_Generator):
        # DataFusion has no hint syntax and no NVL2 / CREATE TABLE LIKE.
        JOIN_HINTS = False
        TABLE_HINTS = False
        QUERY_HINTS = False
        NVL2_SUPPORTED = False
        SUPPORTS_CREATE_TABLE_LIKE = False

        # Render sqlglot's canonical data types as DataFusion equivalents.
        TYPE_MAPPING = {
            **_Generator.TYPE_MAPPING,
            exp.DataType.Type.CHAR: "VARCHAR",
            exp.DataType.Type.NCHAR: "VARCHAR",
            exp.DataType.Type.NVARCHAR: "VARCHAR",
            exp.DataType.Type.TEXT: "VARCHAR",
            exp.DataType.Type.BINARY: "BYTEA",
            exp.DataType.Type.VARBINARY: "BYTEA",
            exp.DataType.Type.DATETIME: "TIMESTAMP",
            exp.DataType.Type.TIMESTAMPLTZ: "TIMESTAMPTZ",
            exp.DataType.Type.TIMESTAMPNTZ: "TIMESTAMP",
        }

        # Inverse of Parser.FUNCTIONS: render canonical expression nodes
        # using DataFusion's function names.
        TRANSFORMS = {
            **_Generator.TRANSFORMS,
            # Array
            exp.Array: rename_func("make_array"),
            exp.ArraySize: rename_func("cardinality"),
            exp.SortArray: rename_func("array_sort"),
            exp.ArrayContains: rename_func("array_has"),
            # Aggregate
            exp.LogicalAnd: rename_func("bool_and"),
            exp.LogicalOr: rename_func("bool_or"),
            exp.BitwiseAndAgg: rename_func("bit_and"),
            exp.BitwiseOrAgg: rename_func("bit_or"),
            exp.BitwiseXorAgg: rename_func("bit_xor"),
            exp.VariancePop: rename_func("var_pop"),
            exp.Variance: rename_func("var_sample"),
            exp.StddevPop: rename_func("stddev_pop"),
            exp.CovarPop: rename_func("covar_pop"),
            exp.CovarSamp: rename_func("covar_samp"),
            exp.ApproxDistinct: rename_func("approx_distinct"),
            exp.ApproxQuantile: rename_func("approx_percentile_cont"),
            exp.GroupConcat: rename_func("string_agg"),
            # Datetime: CURRENT_TIMESTAMP takes no args in DataFusion.
            exp.CurrentTimestamp: lambda *_: "now()",
        }
|
|
||
|
|
||
def to_datafusion_sql(sql: str, source_dialect: str) -> str:
    """Transpile a single SQL statement to the DataFusion dialect.

    Args:
        sql: SQL statement in the source dialect.
        source_dialect: sqlglot dialect name (e.g. "spark", "postgres"),
            matched case-insensitively. Use "datafusion" to skip
            transpilation and return the input unchanged.

    Returns:
        The statement rendered as DataFusion SQL (the input itself for the
        "datafusion" dialect).

    Raises:
        ValueError: If the dialect is unsupported, the SQL is invalid, or the
            input contains more than one statement.
    """
    # Dialect.classes is keyed by lowercase names; normalize so callers may
    # pass "Spark" / "SPARK" etc. without being spuriously rejected.
    dialect = source_dialect.lower()
    if dialect not in Dialect.classes:
        raise ValueError(
            f"Unsupported source dialect '{source_dialect}'. "
            f"Supported dialects: {', '.join(sorted(Dialect.classes))}"
        )
    if dialect == "datafusion":
        # Already in the target dialect; return the input object unchanged
        # rather than round-tripping it through the parser.
        return sql
    try:
        statements = sqlglot.transpile(sql, read=dialect, write="datafusion")
    except sqlglot.errors.SqlglotError as e:
        # Wrap sqlglot's exception hierarchy in the ValueError contract that
        # callers handle.
        raise ValueError(
            f"Failed to transpile SQL from '{source_dialect}' to DataFusion: {e}"
        ) from e
    if len(statements) != 1:
        raise ValueError(
            f"Expected exactly one SQL statement, got {len(statements)}: {statements}"
        )
    return statements[0]
134 changes: 134 additions & 0 deletions
134
integrations/python/dataloader/tests/test_datafusion_sql.py
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,134 @@ | ||
| from __future__ import annotations | ||
|
|
||
| import datafusion | ||
| import pyarrow as pa | ||
| import pytest | ||
|
|
||
| from openhouse.dataloader.datafusion_sql import to_datafusion_sql | ||
|
|
||
| # --------------------------------------------------------------------------- | ||
| # Transpilation tests | ||
| # --------------------------------------------------------------------------- | ||
|
|
||
|
|
||
@pytest.mark.parametrize(
    "sql, dialect, expected",
    [
        # Spark → DataFusion
        ("SELECT `col1`, `col2` FROM `my_table`", "spark", 'SELECT "col1", "col2" FROM "my_table"'),
        ("SELECT SIZE(arr) FROM t", "spark", "SELECT cardinality(arr) FROM t"),
        ("SELECT ARRAY(1, 2, 3)", "spark", "SELECT make_array(1, 2, 3)"),
        ("SELECT UPPER(name) FROM t", "spark", "SELECT upper(name) FROM t"),
        # Unknown functions (e.g. UDFs) must pass through untouched.
        ("SELECT my_udf(col1, col2) FROM t", "spark", "SELECT my_udf(col1, col2) FROM t"),
        ("SELECT IF(x > 0, 'pos', 'neg') FROM t", "spark", "SELECT CASE WHEN x > 0 THEN 'pos' ELSE 'neg' END FROM t"),
        (
            "SELECT CASE WHEN status = 1 THEN 'active' ELSE 'inactive' END FROM t",
            "spark",
            "SELECT CASE WHEN status = 1 THEN 'active' ELSE 'inactive' END FROM t",
        ),
        (
            "SELECT * FROM (SELECT id, name FROM t WHERE id > 10) sub WHERE sub.name IS NOT NULL",
            "spark",
            "SELECT * FROM (SELECT id, name FROM t WHERE id > 10) AS sub WHERE NOT sub.name IS NULL",
        ),
        ("SELECT 'hello world' AS greeting", "spark", "SELECT 'hello world' AS greeting"),
        ("SELECT CURRENT_TIMESTAMP()", "spark", "SELECT now()"),
        # Spark CAST is non-strict, hence TRY_CAST on the DataFusion side.
        ("SELECT CAST(x AS BINARY)", "spark", "SELECT TRY_CAST(x AS BYTEA)"),
        # MySQL → DataFusion
        ("SELECT CAST(x AS CHAR)", "mysql", "SELECT CAST(x AS VARCHAR)"),
        ("SELECT CAST(x AS DATETIME)", "mysql", "SELECT CAST(x AS TIMESTAMP)"),
        # Postgres → DataFusion
        ("SELECT CAST(x AS TEXT)", "postgres", "SELECT CAST(x AS VARCHAR)"),
        # DataFusion → DataFusion (noop)
        (
            "SELECT cardinality(arr) FROM t WHERE x > 10 ORDER BY x LIMIT 5",
            "datafusion",
            "SELECT cardinality(arr) FROM t WHERE x > 10 ORDER BY x LIMIT 5",
        ),
    ],
)
def test_transpilation(sql: str, dialect: str, expected: str) -> None:
    """Assert the full transpiled output for each (sql, source dialect) pair."""
    assert to_datafusion_sql(sql, dialect) == expected
|
|
||
|
|
||
| # --------------------------------------------------------------------------- | ||
| # to_datafusion_sql error handling and edge cases | ||
| # --------------------------------------------------------------------------- | ||
|
|
||
|
|
||
class TestTranslatorEdgeCases:
    """Error-handling and edge-case behavior of ``to_datafusion_sql``."""

    def test_multi_statement_raises(self) -> None:
        # Only a single statement is accepted per call.
        with pytest.raises(ValueError, match="Expected exactly one"):
            to_datafusion_sql("SELECT 1; SELECT 2", "spark")

    def test_unsupported_dialect_raises(self) -> None:
        # Unknown dialect names are rejected up front with a clear message.
        with pytest.raises(ValueError, match="Unsupported source dialect 'nosuchdialect'"):
            to_datafusion_sql("SELECT 1", "nosuchdialect")

    def test_syntax_error_raises(self) -> None:
        # sqlglot parse errors are wrapped in ValueError.
        with pytest.raises(ValueError, match="Failed to transpile SQL from 'spark' to DataFusion"):
            to_datafusion_sql("SELECT * FROM", "spark")

    def test_datafusion_dialect_is_noop(self) -> None:
        # The "datafusion" source dialect returns the *same* string object,
        # proving no parse/generate round-trip happened.
        sql = "SELECT make_array(1, 2, 3)"
        assert to_datafusion_sql(sql, "datafusion") is sql
|
|
||
|
|
||
| # --------------------------------------------------------------------------- | ||
| # DataFusion execution tests (requires datafusion package) | ||
| # --------------------------------------------------------------------------- | ||
|
|
||
|
|
||
def test_datafusion_execution() -> None:
    """Transpiled Spark SQL must actually execute in a DataFusion session."""
    session = datafusion.SessionContext()
    query = to_datafusion_sql("SELECT SIZE(ARRAY(1, 2, 3))", "spark")
    result = session.sql(query).collect()
    first_batch = result[0]
    value = first_batch.column(0)[0].as_py()
    assert value == 3
|
|
||
|
|
||
def test_datafusion_execution_median() -> None:
    """MEDIAN maps to DataFusion's lowercase median() and executes correctly."""
    ctx = datafusion.SessionContext()
    translated = to_datafusion_sql("SELECT MEDIAN(x) FROM (VALUES (1), (2), (3), (4), (5)) AS t(x)", "spark")
    assert translated == "SELECT median(x) FROM (VALUES (1), (2), (3), (4), (5)) AS t(x)"
    batch = ctx.sql(translated).collect()[0]
    # Median of 1..5 is 3.
    assert batch.column(0)[0].as_py() == 3
|
|
||
|
|
||
def test_datafusion_execution_percentile_cont() -> None:
    """PERCENTILE_CONT keeps its WITHIN GROUP form and executes in DataFusion."""
    ctx = datafusion.SessionContext()
    translated = to_datafusion_sql(
        "SELECT PERCENTILE_CONT(0.5) WITHIN GROUP (ORDER BY x) FROM (VALUES (1), (2), (3), (4), (5)) AS t(x)",
        "spark",
    )
    # NULLS FIRST is emitted explicitly because DataFusion defaults to
    # NULLS LAST for ascending order.
    expected = (
        "SELECT percentile_cont(0.5) WITHIN GROUP (ORDER BY x NULLS FIRST)"
        " FROM (VALUES (1), (2), (3), (4), (5)) AS t(x)"
    )
    assert translated == expected
    batch = ctx.sql(translated).collect()[0]
    # 50th percentile of 1..5 is 3.0 (continuous interpolation returns float).
    assert batch.column(0)[0].as_py() == 3.0
|
|
||
|
|
||
def test_datafusion_execution_approx_percentile_cont() -> None:
    """Spark's PERCENTILE_APPROX maps to approx_percentile_cont and executes."""
    ctx = datafusion.SessionContext()
    translated = to_datafusion_sql(
        "SELECT PERCENTILE_APPROX(x, 0.5) FROM (VALUES (1), (2), (3), (4), (5)) AS t(x)",
        "spark",
    )
    assert translated == "SELECT approx_percentile_cont(x, 0.5) FROM (VALUES (1), (2), (3), (4), (5)) AS t(x)"
    batch = ctx.sql(translated).collect()[0]
    # On this tiny exact dataset the approximate median equals the true one.
    assert batch.column(0)[0].as_py() == 3
|
|
||
|
|
||
def test_datafusion_execution_udf() -> None:
    """A UDF name unknown to the transpiler passes through and still executes."""
    ctx = datafusion.SessionContext()

    def double_it(arr: pa.Array) -> pa.Array:
        # Scalar UDF body: doubles each element of the input Arrow array.
        return pa.array([x * 2 for x in arr.to_pylist()])

    ctx.register_udf(datafusion.udf(double_it, [pa.int64()], pa.int64(), "stable", name="double_it"))

    translated = to_datafusion_sql("SELECT double_it(x) FROM (VALUES (5)) AS t(x)", "spark")
    # The UDF call must survive transpilation byte-for-byte.
    assert translated == "SELECT double_it(x) FROM (VALUES (5)) AS t(x)"
    batch = ctx.sql(translated).collect()[0]
    assert batch.column(0)[0].as_py() == 10
Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.