Merged
Commits (18)
b332c5c [DataLoader] Add full DataFusion SQLGlot dialect for SQL transpilation (robreeves, Mar 13, 2026)
f796b3b [DataLoader] Generalize SQL translator to accept any source dialect (robreeves, Mar 13, 2026)
0c9222f [DataLoader] Move translate_to_datafusion into datafusion_dialect module (robreeves, Mar 13, 2026)
a50df1b [DataLoader] Rename module to datafusion_sql and function to to_datafusion_sql (robreeves, Mar 13, 2026)
e12bb4d [DataLoader] Handle datafusion as noop dialect and improve error messages (robreeves, Mar 13, 2026)
6106a1c [DataLoader] Assert full output queries in dialect tests (robreeves, Mar 13, 2026)
b39308d [DataLoader] Use to_datafusion_sql in tests and deduplicate (robreeves, Mar 13, 2026)
91456b9 [DataLoader] Consolidate tests: merge type mappings, single e2e execu… (robreeves, Mar 13, 2026)
af32561 [DataLoader] Consolidate all transpilation tests into single parametr… (robreeves, Mar 13, 2026)
88b0fc5 [DataLoader] Inline SUPPORTED_SOURCE_DIALECTS into error message (robreeves, Mar 13, 2026)
a4b131c [DataLoader] Move identity round-trip into test_transpilation paramet… (robreeves, Mar 13, 2026)
f1a9ea5 [DataLoader] Remove SPARK constant, inline string literals in tests (robreeves, Mar 13, 2026)
0305583 [DataLoader] Bump sqlglot minimum version to 29.0.0 (robreeves, Mar 13, 2026)
451aba8 [DataLoader] Wrap sqlglot transpile errors in ValueError (robreeves, Mar 16, 2026)
7bf4888 [DataLoader] Fix median/percentile mappings, add approx_percentile_cont (robreeves, Mar 16, 2026)
ee7e995 [DataLoader] Add docstring to to_datafusion_sql (robreeves, Mar 16, 2026)
0e55872 [DataLoader] Add UDF execution test for transpiler (robreeves, Mar 16, 2026)
65d0367 [DataLoader] Fix line length lint error in test (robreeves, Mar 16, 2026)
4 changes: 2 additions & 2 deletions integrations/python/dataloader/pyproject.toml
@@ -13,7 +13,7 @@ readme = "README.md"
requires-python = ">=3.12"
license = {text = "BSD-2-Clause"}
keywords = ["openhouse", "data-loader", "lakehouse", "iceberg", "datafusion"]
-dependencies = ["datafusion==51.0.0", "pyiceberg @ git+https://github.com/sumedhsakdeo/iceberg-python@75ba28bfc6d8bbeac398357c6db80327632a2dc8", "requests>=2.31.0", "tenacity>=8.0.0"]
+dependencies = ["datafusion==51.0.0", "pyiceberg @ git+https://github.com/sumedhsakdeo/iceberg-python@75ba28bfc6d8bbeac398357c6db80327632a2dc8", "requests>=2.31.0", "sqlglot>=29.0.0", "tenacity>=8.0.0"]

[project.optional-dependencies]
dev = ["responses>=0.25.0", "ruff>=0.9.0", "pytest>=8.0.0", "twine>=6.0.0", "mypy>=1.14.0", "types-requests>=2.31.0"]
@@ -32,7 +32,7 @@ warn_unused_configs = true
disallow_untyped_defs = true

[[tool.mypy.overrides]]
-module = ["datafusion.*", "pyiceberg.*", "pyarrow.*", "tenacity.*"]
+module = ["datafusion.*", "pyiceberg.*", "pyarrow.*", "sqlglot.*", "tenacity.*"]
ignore_missing_imports = true

[tool.ruff]
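For reviewers who want to confirm the new sqlglot floor locally, a throwaway check along these lines works (it assumes the packaging library is available in the environment; neither the check nor packaging is part of this PR):

import sqlglot
from packaging.version import Version

# The dialect added below relies on sqlglot>=29.0.0, per the bump in this PR.
assert Version(sqlglot.__version__) >= Version("29.0.0"), sqlglot.__version__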
@@ -0,0 +1,122 @@
from __future__ import annotations

import sqlglot
from sqlglot import exp
from sqlglot.dialects.dialect import Dialect, NormalizationStrategy, rename_func
from sqlglot.generator import Generator as _Generator
from sqlglot.parser import Parser as _Parser
from sqlglot.tokens import Tokenizer as _Tokenizer
from sqlglot.tokens import TokenType


class DataFusion(Dialect):
    # Match DataFusion's defaults: unquoted identifiers and function names
    # are normalized to lowercase, and NULLs sort last.
    NORMALIZE_FUNCTIONS: bool | str = "lower"
    NORMALIZATION_STRATEGY = NormalizationStrategy.LOWERCASE
    NULL_ORDERING = "nulls_are_last"
    INDEX_OFFSET = 0
    TYPED_DIVISION = True
    SUPPORTS_USER_DEFINED_TYPES = False
    LOG_BASE_FIRST = True
    SUPPORTS_ORDER_BY_ALL = True

    class Tokenizer(_Tokenizer):
        IDENTIFIERS = ['"']
        KEYWORDS = {
            **_Tokenizer.KEYWORDS,
            "UTF8": TokenType.TEXT,
        }

    class Parser(_Parser):
        # DataFusion-specific function names mapped onto sqlglot's
        # dialect-neutral expression nodes.
        FUNCTIONS = {
            **_Parser.FUNCTIONS,
            "MAKE_ARRAY": exp.Array.from_arg_list,
            "CARDINALITY": exp.ArraySize.from_arg_list,
            "ARRAY_SORT": exp.SortArray.from_arg_list,
            "ARRAY_HAS": exp.ArrayContains.from_arg_list,
            "BOOL_AND": exp.LogicalAnd.from_arg_list,
            "BOOL_OR": exp.LogicalOr.from_arg_list,
            "BIT_AND": exp.BitwiseAndAgg.from_arg_list,
            "BIT_OR": exp.BitwiseOrAgg.from_arg_list,
            "BIT_XOR": exp.BitwiseXorAgg.from_arg_list,
            "VAR_POP": exp.VariancePop.from_arg_list,
            "VAR_SAMPLE": exp.Variance.from_arg_list,
            "STDDEV_POP": exp.StddevPop.from_arg_list,
            "COVAR_POP": exp.CovarPop.from_arg_list,
            "COVAR_SAMP": exp.CovarSamp.from_arg_list,
            "APPROX_DISTINCT": exp.ApproxDistinct.from_arg_list,
            "APPROX_PERCENTILE_CONT": exp.ApproxQuantile.from_arg_list,
            "STRING_AGG": exp.GroupConcat.from_arg_list,
            "NOW": exp.CurrentTimestamp.from_arg_list,
        }

    class Generator(_Generator):
        JOIN_HINTS = False
        TABLE_HINTS = False
        QUERY_HINTS = False
        NVL2_SUPPORTED = False
        SUPPORTS_CREATE_TABLE_LIKE = False

        # Types DataFusion names differently from other dialects.
        TYPE_MAPPING = {
            **_Generator.TYPE_MAPPING,
            exp.DataType.Type.CHAR: "VARCHAR",
            exp.DataType.Type.NCHAR: "VARCHAR",
            exp.DataType.Type.NVARCHAR: "VARCHAR",
            exp.DataType.Type.TEXT: "VARCHAR",
            exp.DataType.Type.BINARY: "BYTEA",
            exp.DataType.Type.VARBINARY: "BYTEA",
            exp.DataType.Type.DATETIME: "TIMESTAMP",
            exp.DataType.Type.TIMESTAMPLTZ: "TIMESTAMPTZ",
            exp.DataType.Type.TIMESTAMPNTZ: "TIMESTAMP",
        }

        # Inverse of Parser.FUNCTIONS: render expression nodes with
        # DataFusion's function names.
        TRANSFORMS = {
            **_Generator.TRANSFORMS,
            # Array
            exp.Array: rename_func("make_array"),
            exp.ArraySize: rename_func("cardinality"),
            exp.SortArray: rename_func("array_sort"),
            exp.ArrayContains: rename_func("array_has"),
            # Aggregate
            exp.LogicalAnd: rename_func("bool_and"),
            exp.LogicalOr: rename_func("bool_or"),
            exp.BitwiseAndAgg: rename_func("bit_and"),
            exp.BitwiseOrAgg: rename_func("bit_or"),
            exp.BitwiseXorAgg: rename_func("bit_xor"),
            exp.VariancePop: rename_func("var_pop"),
            exp.Variance: rename_func("var_sample"),
            exp.StddevPop: rename_func("stddev_pop"),
            exp.CovarPop: rename_func("covar_pop"),
            exp.CovarSamp: rename_func("covar_samp"),
            exp.ApproxDistinct: rename_func("approx_distinct"),
            exp.ApproxQuantile: rename_func("approx_percentile_cont"),
            exp.GroupConcat: rename_func("string_agg"),
            # Datetime
            exp.CurrentTimestamp: lambda *_: "now()",
        }


def to_datafusion_sql(sql: str, source_dialect: str) -> str:
    """Transpile a single SQL statement to the DataFusion dialect.

    Args:
        sql: SQL statement in the source dialect.
        source_dialect: sqlglot dialect name (e.g. "spark", "postgres"). Use
            "datafusion" to skip transpilation and return the input unchanged.

    Raises:
        ValueError: If the dialect is unsupported, the SQL is invalid, or the
            input contains more than one statement.
    """
    if source_dialect not in Dialect.classes:
        raise ValueError(
            f"Unsupported source dialect '{source_dialect}'. Supported dialects: {', '.join(sorted(Dialect.classes))}"
        )
    if source_dialect == "datafusion":
        return sql
    try:
        statements = sqlglot.transpile(sql, read=source_dialect, write="datafusion")
    except sqlglot.errors.SqlglotError as e:
        raise ValueError(f"Failed to transpile SQL from '{source_dialect}' to DataFusion: {e}") from e
    if len(statements) != 1:
        raise ValueError(f"Expected exactly one SQL statement, got {len(statements)}: {statements}")
    return statements[0]
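For orientation, a minimal usage sketch of the new helper (the import path is taken from the test file below; the outputs shown match assertions in test_transpilation):

from openhouse.dataloader.datafusion_sql import to_datafusion_sql

# Spark backticks become DataFusion double-quoted identifiers.
print(to_datafusion_sql("SELECT `col1` FROM `my_table`", "spark"))
# SELECT "col1" FROM "my_table"

# "datafusion" as the source dialect is a no-op passthrough.
print(to_datafusion_sql("SELECT make_array(1, 2, 3)", "datafusion"))
# SELECT make_array(1, 2, 3)

# Unsupported dialects, invalid SQL, and multi-statement input all raise ValueError.
try:
    to_datafusion_sql("SELECT 1; SELECT 2", "spark")
except ValueError as err:
    print(err)  # Expected exactly one SQL statement, got 2: [...]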
134 changes: 134 additions & 0 deletions integrations/python/dataloader/tests/test_datafusion_sql.py
@@ -0,0 +1,134 @@
from __future__ import annotations

import datafusion
import pyarrow as pa
import pytest

from openhouse.dataloader.datafusion_sql import to_datafusion_sql

# ---------------------------------------------------------------------------
# Transpilation tests
# ---------------------------------------------------------------------------


@pytest.mark.parametrize(
    "sql, dialect, expected",
    [
        # Spark → DataFusion
        ("SELECT `col1`, `col2` FROM `my_table`", "spark", 'SELECT "col1", "col2" FROM "my_table"'),
        ("SELECT SIZE(arr) FROM t", "spark", "SELECT cardinality(arr) FROM t"),
        ("SELECT ARRAY(1, 2, 3)", "spark", "SELECT make_array(1, 2, 3)"),
        ("SELECT UPPER(name) FROM t", "spark", "SELECT upper(name) FROM t"),
        ("SELECT my_udf(col1, col2) FROM t", "spark", "SELECT my_udf(col1, col2) FROM t"),
        ("SELECT IF(x > 0, 'pos', 'neg') FROM t", "spark", "SELECT CASE WHEN x > 0 THEN 'pos' ELSE 'neg' END FROM t"),
        (
            "SELECT CASE WHEN status = 1 THEN 'active' ELSE 'inactive' END FROM t",
            "spark",
            "SELECT CASE WHEN status = 1 THEN 'active' ELSE 'inactive' END FROM t",
        ),
        (
            "SELECT * FROM (SELECT id, name FROM t WHERE id > 10) sub WHERE sub.name IS NOT NULL",
            "spark",
            "SELECT * FROM (SELECT id, name FROM t WHERE id > 10) AS sub WHERE NOT sub.name IS NULL",
        ),
        ("SELECT 'hello world' AS greeting", "spark", "SELECT 'hello world' AS greeting"),
        ("SELECT CURRENT_TIMESTAMP()", "spark", "SELECT now()"),
        ("SELECT CAST(x AS BINARY)", "spark", "SELECT TRY_CAST(x AS BYTEA)"),
        # MySQL → DataFusion
        ("SELECT CAST(x AS CHAR)", "mysql", "SELECT CAST(x AS VARCHAR)"),
        ("SELECT CAST(x AS DATETIME)", "mysql", "SELECT CAST(x AS TIMESTAMP)"),
        # Postgres → DataFusion
        ("SELECT CAST(x AS TEXT)", "postgres", "SELECT CAST(x AS VARCHAR)"),
        # DataFusion → DataFusion (noop)
        (
            "SELECT cardinality(arr) FROM t WHERE x > 10 ORDER BY x LIMIT 5",
            "datafusion",
            "SELECT cardinality(arr) FROM t WHERE x > 10 ORDER BY x LIMIT 5",
        ),
    ],
)
def test_transpilation(sql: str, dialect: str, expected: str) -> None:
    assert to_datafusion_sql(sql, dialect) == expected


# ---------------------------------------------------------------------------
# to_datafusion_sql error handling and edge cases
# ---------------------------------------------------------------------------


class TestTranslatorEdgeCases:
    def test_multi_statement_raises(self) -> None:
        with pytest.raises(ValueError, match="Expected exactly one"):
            to_datafusion_sql("SELECT 1; SELECT 2", "spark")

    def test_unsupported_dialect_raises(self) -> None:
        with pytest.raises(ValueError, match="Unsupported source dialect 'nosuchdialect'"):
            to_datafusion_sql("SELECT 1", "nosuchdialect")

    def test_syntax_error_raises(self) -> None:
        with pytest.raises(ValueError, match="Failed to transpile SQL from 'spark' to DataFusion"):
            to_datafusion_sql("SELECT * FROM", "spark")

    def test_datafusion_dialect_is_noop(self) -> None:
        sql = "SELECT make_array(1, 2, 3)"
        assert to_datafusion_sql(sql, "datafusion") is sql


# ---------------------------------------------------------------------------
# DataFusion execution tests (requires datafusion package)
# ---------------------------------------------------------------------------


def test_datafusion_execution() -> None:
    ctx = datafusion.SessionContext()
    translated = to_datafusion_sql("SELECT SIZE(ARRAY(1, 2, 3))", "spark")
    batch = ctx.sql(translated).collect()[0]
    assert batch.column(0)[0].as_py() == 3


def test_datafusion_execution_median() -> None:
    ctx = datafusion.SessionContext()
    translated = to_datafusion_sql("SELECT MEDIAN(x) FROM (VALUES (1), (2), (3), (4), (5)) AS t(x)", "spark")
    assert translated == "SELECT median(x) FROM (VALUES (1), (2), (3), (4), (5)) AS t(x)"
    batch = ctx.sql(translated).collect()[0]
    assert batch.column(0)[0].as_py() == 3


def test_datafusion_execution_percentile_cont() -> None:
    ctx = datafusion.SessionContext()
    translated = to_datafusion_sql(
        "SELECT PERCENTILE_CONT(0.5) WITHIN GROUP (ORDER BY x) FROM (VALUES (1), (2), (3), (4), (5)) AS t(x)",
        "spark",
    )
    expected = (
        "SELECT percentile_cont(0.5) WITHIN GROUP (ORDER BY x NULLS FIRST)"
        " FROM (VALUES (1), (2), (3), (4), (5)) AS t(x)"
    )
    assert translated == expected
    batch = ctx.sql(translated).collect()[0]
    assert batch.column(0)[0].as_py() == 3.0


def test_datafusion_execution_approx_percentile_cont() -> None:
    ctx = datafusion.SessionContext()
    translated = to_datafusion_sql(
        "SELECT PERCENTILE_APPROX(x, 0.5) FROM (VALUES (1), (2), (3), (4), (5)) AS t(x)",
        "spark",
    )
    assert translated == "SELECT approx_percentile_cont(x, 0.5) FROM (VALUES (1), (2), (3), (4), (5)) AS t(x)"
    batch = ctx.sql(translated).collect()[0]
    assert batch.column(0)[0].as_py() == 3


def test_datafusion_execution_udf() -> None:
    ctx = datafusion.SessionContext()

    def double_it(arr: pa.Array) -> pa.Array:
        return pa.array([x * 2 for x in arr.to_pylist()])

    ctx.register_udf(datafusion.udf(double_it, [pa.int64()], pa.int64(), "stable", name="double_it"))

    translated = to_datafusion_sql("SELECT double_it(x) FROM (VALUES (5)) AS t(x)", "spark")
    assert translated == "SELECT double_it(x) FROM (VALUES (5)) AS t(x)"
    batch = ctx.sql(translated).collect()[0]
    assert batch.column(0)[0].as_py() == 10
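Because the dialect defines a Parser as well as a Generator, sqlglot can also read DataFusion SQL once the module is imported; a sketch of the reverse direction, which the suite does not assert (the expected output here is an assumption):

import sqlglot

from openhouse.dataloader import datafusion_sql  # noqa: F401 - importing registers the "datafusion" dialect

print(sqlglot.transpile("SELECT make_array(1, 2, 3)", read="datafusion", write="spark")[0])
# Assumed output: SELECT ARRAY(1, 2, 3)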
11 changes: 11 additions & 0 deletions integrations/python/dataloader/uv.lock

Some generated files are not rendered by default.