diff --git a/README.md b/README.md
index 92fa355416..1544ec8ddc 100644
--- a/README.md
+++ b/README.md
@@ -1,6 +1,6 @@
 ![SQLGlot logo](sqlglot.png)
 
-SQLGlot is a no-dependency SQL parser, transpiler, optimizer, and engine. It can be used to format SQL or translate between [31 different dialects](https://github.com/tobymao/sqlglot/blob/main/sqlglot/dialects/__init__.py) like [DuckDB](https://duckdb.org/), [Presto](https://prestodb.io/) / [Trino](https://trino.io/), [Spark](https://spark.apache.org/) / [Databricks](https://www.databricks.com/), [Snowflake](https://www.snowflake.com/en/), and [BigQuery](https://cloud.google.com/bigquery/). It aims to read a wide variety of SQL inputs and output syntactically and semantically correct SQL in the targeted dialects.
+SQLGlot is a no-dependency SQL parser, transpiler, optimizer, and engine. It can be used to format SQL or translate between [32 different dialects](https://github.com/tobymao/sqlglot/blob/main/sqlglot/dialects/__init__.py) like [DuckDB](https://duckdb.org/), [Presto](https://prestodb.io/) / [Trino](https://trino.io/), [Spark](https://spark.apache.org/) / [Databricks](https://www.databricks.com/), [Snowflake](https://www.snowflake.com/en/), and [BigQuery](https://cloud.google.com/bigquery/). It aims to read a wide variety of SQL inputs and output syntactically and semantically correct SQL in the targeted dialects.
 
 It is a very comprehensive generic SQL parser with a robust [test suite](https://github.com/tobymao/sqlglot/blob/main/tests/). It is also quite [performant](#benchmarks), while being written purely in Python.
 
@@ -586,6 +586,7 @@ x + interval '1' month
 | BigQuery | Official |
 | ClickHouse | Official |
 | Databricks | Official |
+| Db2 | Community |
 | Doris | Community |
 | Dremio | Community |
 | Drill | Community |
diff --git a/sqlglot/dialects/__init__.py b/sqlglot/dialects/__init__.py
index 14e5796b0e..1ea521112e 100644
--- a/sqlglot/dialects/__init__.py
+++ b/sqlglot/dialects/__init__.py
@@ -69,6 +69,7 @@ class Generator(Generator):
     "BigQuery",
     "ClickHouse",
     "Databricks",
+    "Db2",
     "Doris",
     "Dremio",
     "Drill",
diff --git a/sqlglot/dialects/db2.py b/sqlglot/dialects/db2.py
new file mode 100644
index 0000000000..ee8e0ca98a
--- /dev/null
+++ b/sqlglot/dialects/db2.py
@@ -0,0 +1,184 @@
+from __future__ import annotations
+
+import typing as t
+
+from sqlglot import exp, generator, parser, tokens, transforms
+from sqlglot.dialects.dialect import (
+    Dialect,
+    NormalizationStrategy,
+    max_or_greatest,
+    min_or_least,
+    no_ilike_sql,
+    no_pivot_sql,
+    no_trycast_sql,
+    rename_func,
+    trim_sql,
+)
+from sqlglot.helper import seq_get
+from sqlglot.tokens import TokenType
+
+
+def _date_add_sql(
+    kind: str,
+) -> t.Callable[[Db2.Generator, exp.DateAdd | exp.DateSub], str]:
+    """Return a transform rendering DB2 date arithmetic: `<date> <kind> <n> <unit>`."""
+
+    def func(self: Db2.Generator, expression: exp.DateAdd | exp.DateSub) -> str:
+        this = self.sql(expression, "this")
+        unit = expression.args.get("unit")
+        value = self._simplify_unless_literal(expression.expression)
+
+        if not isinstance(value, exp.Literal):
+            self.unsupported("Cannot add non literal")
+
+        value_sql = self.sql(value)
+        unit_sql = self.sql(unit) if unit else "DAY"
+
+        return f"{this} {kind} {value_sql} {unit_sql}"
+
+    return func
+
+
+class Db2(Dialect):
+    # DB2 is case-insensitive by default: unquoted identifiers fold to upper case
+    NORMALIZATION_STRATEGY = NormalizationStrategy.UPPERCASE
+
+    # NULLs sort as if they were the largest values
+    NULL_ORDERING = "nulls_are_large"
+
+    TYPED_DIVISION = True
+    SAFE_DIVISION = True
+
+    # Time format mappings for DB2
+    # https://www.ibm.com/docs/en/db2/11.5?topic=functions-timestamp-format
+    TIME_MAPPING = {
+        "YYYY": "%Y",
+        "YY": "%y",
+        "MM": "%m",
+        "DD": "%d",
+        "HH": "%H",
+        "HH12": "%I",
+        "HH24": "%H",
+        "MI": "%M",
+        "SS": "%S",
+        "FF": "%f",
+        "FF3": "%f",
+        "FF6": "%f",
+        "MON": "%b",
+        "MONTH": "%B",
+        "DY": "%a",
+        "DAY": "%A",
+    }
+
+    class Tokenizer(tokens.Tokenizer):
+        # DB2 uses @ for variables
+        VAR_SINGLE_TOKENS = {"@"}
+
+        # DB2 specific keywords
+        KEYWORDS = {
+            **tokens.Tokenizer.KEYWORDS,
+            "CHAR": TokenType.CHAR,
+            "CLOB": TokenType.TEXT,
+            "DBCLOB": TokenType.TEXT,
+            "DECFLOAT": TokenType.DECIMAL,
+            "GRAPHIC": TokenType.NCHAR,
+            "VARGRAPHIC": TokenType.NVARCHAR,
+            "SMALLINT": TokenType.SMALLINT,
+            "INTEGER": TokenType.INT,
+            "BIGINT": TokenType.BIGINT,
+            "REAL": TokenType.FLOAT,
+            "DOUBLE": TokenType.DOUBLE,
+            "DECIMAL": TokenType.DECIMAL,
+            "NUMERIC": TokenType.DECIMAL,
+            "VARCHAR": TokenType.VARCHAR,
+            "TIMESTAMP": TokenType.TIMESTAMP,
+            "TIMESTMP": TokenType.TIMESTAMP,
+            # NOTE(review): tokenizing system schema names as keywords may break
+            # qualified references like SYSIBM.SYSDUMMY1 -- confirm against parser
+            "SYSIBM": TokenType.SCHEMA,
+            "SYSFUN": TokenType.SCHEMA,
+            "SYSTOOLS": TokenType.SCHEMA,
+        }
+
+    class Parser(parser.Parser):
+        FUNCTIONS = {
+            **parser.Parser.FUNCTIONS,
+            # CHAR(x) casts its argument to a character string
+            "CHAR": lambda args: exp.Cast(this=seq_get(args, 0), to=exp.DataType.build("CHAR")),
+            "DAYOFWEEK": lambda args: exp.Extract(
+                this=exp.var("DAYOFWEEK"), expression=seq_get(args, 0)
+            ),
+            "DAYOFYEAR": lambda args: exp.Extract(
+                this=exp.var("DAYOFYEAR"), expression=seq_get(args, 0)
+            ),
+            "POSSTR": lambda args: exp.StrPosition(this=seq_get(args, 0), substr=seq_get(args, 1)),
+            "VARCHAR_FORMAT": lambda args: exp.TimeToStr(
+                this=seq_get(args, 0), format=seq_get(args, 1)
+            ),
+        }
+
+    class Generator(generator.Generator):
+        LIMIT_FETCH = "FETCH"
+        JOIN_HINTS = False
+        TABLE_HINTS = False
+        QUERY_HINTS = False
+        NVL2_SUPPORTED = False
+        LAST_DAY_SUPPORTS_DATE_PART = False
+
+        TYPE_MAPPING = {
+            **generator.Generator.TYPE_MAPPING,
+            exp.DataType.Type.BOOLEAN: "SMALLINT",
+            exp.DataType.Type.TINYINT: "SMALLINT",
+            exp.DataType.Type.BINARY: "BLOB",
+            exp.DataType.Type.VARBINARY: "BLOB",
+            exp.DataType.Type.TEXT: "CLOB",
+            exp.DataType.Type.NCHAR: "GRAPHIC",
+            exp.DataType.Type.NVARCHAR: "VARGRAPHIC",
+            exp.DataType.Type.TIMESTAMPTZ: "TIMESTAMP",
+            exp.DataType.Type.DATETIME: "TIMESTAMP",
+        }
+
+        TRANSFORMS = {
+            **generator.Generator.TRANSFORMS,
+            exp.ArgMax: rename_func("MAX"),
+            exp.ArgMin: rename_func("MIN"),
+            exp.DateAdd: _date_add_sql("+"),
+            exp.DateSub: _date_add_sql("-"),
+            exp.DateDiff: lambda self,
+            e: f"{self.func('DAYS', e.this)} - {self.func('DAYS', e.expression)}",
+            exp.CurrentDate: lambda self, e: "CURRENT DATE",
+            exp.CurrentTimestamp: lambda self, e: "CURRENT TIMESTAMP",
+            exp.ILike: no_ilike_sql,
+            exp.Max: max_or_greatest,
+            exp.Min: min_or_least,
+            exp.Pivot: no_pivot_sql,
+            exp.Select: transforms.preprocess([transforms.eliminate_distinct_on]),
+            exp.StrPosition: rename_func("POSSTR"),
+            exp.TimeToStr: rename_func("VARCHAR_FORMAT"),
+            exp.TryCast: no_trycast_sql,
+            exp.Trim: trim_sql,
+        }
+
+        def extract_sql(self, expression: exp.Extract) -> str:
+            # DAYOFWEEK/DAYOFYEAR are plain functions in DB2, not EXTRACT parts
+            this = self.sql(expression, "this")
+            expression_sql = self.sql(expression, "expression")
+
+            if this.upper() in ("DAYOFWEEK", "DAYOFYEAR"):
+                return f"{this.upper()}({expression_sql})"
+
+            return f"EXTRACT({this} FROM {expression_sql})"
+
+        def offset_sql(self, expression: exp.Offset) -> str:
+            # DB2 requires the ROWS keyword: OFFSET n ROWS
+            return f"{super().offset_sql(expression)} ROWS"
+
+        def fetch_sql(self, expression: exp.Fetch) -> str:
+            count = expression.args.get("count")
+            if count:
+                return f" FETCH FIRST {self.sql(count)} ROWS ONLY"
+            return " FETCH FIRST ROW ONLY"
+
+        def boolean_sql(self, expression: exp.Boolean) -> str:
+            # DB2 has no BOOLEAN literals; render 1/0
+            return "1" if expression.this else "0"
diff --git a/sqlglot/dialects/dialect.py b/sqlglot/dialects/dialect.py
index af232b5c7f..ff62629d3a 100644
--- a/sqlglot/dialects/dialect.py
+++ b/sqlglot/dialects/dialect.py
@@ -86,6 +86,7 @@ class Dialects(str, Enum):
     BIGQUERY = "bigquery"
     CLICKHOUSE = "clickhouse"
     DATABRICKS = "databricks"
+    DB2 = "db2"
     DORIS = "doris"
     DREMIO = "dremio"
     DRILL = "drill"
diff --git a/tests/dialects/test_db2.py b/tests/dialects/test_db2.py
new file mode 100644
index 0000000000..e0064c4d1a
--- /dev/null
+++ b/tests/dialects/test_db2.py
@@ -0,0 +1,182 @@
+from tests.dialects.test_dialect import Validator
+
+
+class TestDB2(Validator):
+    dialect = "db2"
+
+    def test_db2(self):
+        # Test basic identity
+        self.validate_identity("SELECT * FROM table1")
+        self.validate_identity("SELECT a, b, c FROM table1")
+
+        # Test DB2 specific data types
+        self.validate_identity("CREATE TABLE t (a SMALLINT, b INT, c BIGINT)")
+        self.validate_identity("CREATE TABLE t (a CHAR(10), b VARCHAR(100))")
+        self.validate_identity("CREATE TABLE t (a DECIMAL(10, 2))")
+        self.validate_identity("CREATE TABLE t (a TIMESTAMP)")
+
+        # Test FETCH FIRST syntax
+        self.validate_identity("SELECT * FROM t FETCH FIRST 10 ROWS ONLY")
+        self.validate_identity("SELECT * FROM t FETCH FIRST ROW ONLY")
+
+        # Test OFFSET syntax
+        self.validate_identity("SELECT * FROM t OFFSET 5 ROWS")
+        self.validate_identity("SELECT * FROM t OFFSET 5 ROWS FETCH FIRST 10 ROWS ONLY")
+
+        # Test CURRENT_DATE and CURRENT_TIMESTAMP
+        self.validate_all(
+            "SELECT CURRENT_DATE",
+            write={
+                "db2": "SELECT CURRENT DATE",
+            },
+        )
+        self.validate_all(
+            "SELECT CURRENT_TIMESTAMP",
+            write={
+                "db2": "SELECT CURRENT TIMESTAMP",
+            },
+        )
+
+        # Test concatenation with ||
+        self.validate_identity("SELECT a || b FROM t")
+        self.validate_identity("SELECT a || b || c FROM t")
+
+        # Test POSSTR function (DB2's string position function)
+        self.validate_all(
+            "SELECT STRPOS(haystack, needle)",
+            write={
+                "db2": "SELECT POSSTR(haystack, needle)",
+            },
+        )
+
+        # Test boolean conversion (DB2 uses 0/1 for boolean)
+        self.validate_all(
+            "SELECT TRUE, FALSE",
+            write={
+                "db2": "SELECT 1, 0",
+            },
+        )
+
+        # Test DAYOFWEEK and DAYOFYEAR extracts
+        self.validate_all(
+            "SELECT EXTRACT(DAYOFWEEK FROM date_col)",
+            write={
+                "db2": "SELECT DAYOFWEEK(date_col)",
+            },
+        )
+
+        self.validate_all(
+            "SELECT EXTRACT(DAYOFYEAR FROM date_col)",
+            write={
+                "db2": "SELECT DAYOFYEAR(date_col)",
+            },
+        )
+
+        # Test VARCHAR_FORMAT (DB2's time to string function)
+        self.validate_all(
+            "SELECT TIME_TO_STR(timestamp_col, 'YYYY-MM-DD')",
+            write={
+                "db2": "SELECT VARCHAR_FORMAT(timestamp_col, 'YYYY-MM-DD')",
+            },
+        )
+
+        # Test DATEDIFF conversion
+        self.validate_all(
+            "SELECT DATEDIFF(date1, date2)",
+            write={
+                "db2": "SELECT DAYS(date1) - DAYS(date2)",
+            },
+        )
+
+        # Test joins
+        self.validate_identity("SELECT * FROM t1 INNER JOIN t2 ON t1.id = t2.id")
+        self.validate_identity("SELECT * FROM t1 LEFT JOIN t2 ON t1.id = t2.id")
+        self.validate_identity("SELECT * FROM t1 RIGHT JOIN t2 ON t1.id = t2.id")
+
+        # Test subqueries
+        self.validate_identity("SELECT * FROM (SELECT a, b FROM t1) AS subq")
+
+        # Test aggregations
+        self.validate_identity("SELECT COUNT(*) FROM t")
+        self.validate_identity("SELECT SUM(amount) FROM t")
+        self.validate_identity("SELECT AVG(amount) FROM t")
+        self.validate_identity("SELECT MIN(amount), MAX(amount) FROM t")
+
+        # Test GROUP BY and HAVING
+        self.validate_identity("SELECT category, COUNT(*) FROM t GROUP BY category")
+        self.validate_identity(
+            "SELECT category, COUNT(*) FROM t GROUP BY category HAVING COUNT(*) > 5"
+        )
+
+        # Test ORDER BY
+        self.validate_identity("SELECT * FROM t ORDER BY a")
+        self.validate_identity("SELECT * FROM t ORDER BY a DESC")
+        self.validate_identity("SELECT * FROM t ORDER BY a, b DESC")
+
+        # Test CASE expressions
+        self.validate_identity("SELECT CASE WHEN a > 0 THEN 'positive' ELSE 'negative' END FROM t")
+        self.validate_identity(
+            "SELECT CASE a WHEN 1 THEN 'one' WHEN 2 THEN 'two' ELSE 'other' END FROM t"
+        )
+
+        # Test IN clause
+        self.validate_identity("SELECT * FROM t WHERE a IN (1, 2, 3)")
+        self.validate_all(
+            "SELECT * FROM t WHERE a NOT IN (1, 2, 3)",
+            write={
+                "db2": "SELECT * FROM t WHERE NOT a IN (1, 2, 3)",
+            },
+        )
+
+        # Test BETWEEN
+        self.validate_identity("SELECT * FROM t WHERE a BETWEEN 1 AND 10")
+
+        # Test LIKE
+        self.validate_identity("SELECT * FROM t WHERE name LIKE 'John%'")
+
+        # Test NULL handling
+        self.validate_identity("SELECT * FROM t WHERE a IS NULL")
+        self.validate_all(
+            "SELECT * FROM t WHERE a IS NOT NULL",
+            write={
+                "db2": "SELECT * FROM t WHERE NOT a IS NULL",
+            },
+        )
+        self.validate_identity("SELECT COALESCE(a, b, c) FROM t")
+
+        # Test UNION
+        self.validate_identity("SELECT a FROM t1 UNION SELECT a FROM t2")
+        self.validate_identity("SELECT a FROM t1 UNION ALL SELECT a FROM t2")
+
+        # Test WITH (CTE)
+        self.validate_identity("WITH cte AS (SELECT * FROM t1) SELECT * FROM cte")
+
+        # Test INSERT
+        self.validate_identity("INSERT INTO t (a, b) VALUES (1, 2)")
+
+        # Test UPDATE
+        self.validate_identity("UPDATE t SET a = 1 WHERE b = 2")
+
+        # Test DELETE
+        self.validate_identity("DELETE FROM t WHERE a = 1")
+
+        # Test CREATE TABLE
+        self.validate_identity("CREATE TABLE t (id INT, name VARCHAR(100))")
+
+        # Test DROP TABLE
+        self.validate_identity("DROP TABLE t")
+
+        # Test MAX/MIN with GREATEST/LEAST
+        self.validate_all(
+            "SELECT MAX(a, b, c)",
+            write={
+                "db2": "SELECT GREATEST(a, b, c)",
+            },
+        )
+
+        self.validate_all(
+            "SELECT MIN(a, b, c)",
+            write={
+                "db2": "SELECT LEAST(a, b, c)",
+            },
+        )