diff --git a/sqlmesh/core/engine_adapter/mysql.py b/sqlmesh/core/engine_adapter/mysql.py index 66759dc440..795bc61f42 100644 --- a/sqlmesh/core/engine_adapter/mysql.py +++ b/sqlmesh/core/engine_adapter/mysql.py @@ -2,6 +2,7 @@ import logging import typing as t +from functools import reduce from sqlglot import exp, parse_one @@ -21,7 +22,7 @@ ) if t.TYPE_CHECKING: - from sqlmesh.core._typing import SchemaName, TableName + from sqlmesh.core._typing import QueryOrDF, SchemaName, TableName logger = logging.getLogger(__name__) @@ -186,5 +187,84 @@ def _create_table_like( ) ) + def _replace_by_key( + self, + target_table: TableName, + source_table: QueryOrDF, + target_columns_to_types: t.Optional[t.Dict[str, exp.DataType]], + key: t.Sequence[exp.Expr], + is_unique_key: bool, + source_columns: t.Optional[t.List[str]] = None, + ) -> None: + if len(key) <= 1: + return super()._replace_by_key( + target_table, source_table, target_columns_to_types, key, is_unique_key, source_columns + ) + + if target_columns_to_types is None: + target_columns_to_types = self.columns(target_table) + + temp_table = self._get_temp_table(target_table) + column_names = list(target_columns_to_types or []) + + target_alias = "_target" + temp_alias = "_temp" + + with self.transaction(): + self.ctas( + temp_table, + source_table, + target_columns_to_types=target_columns_to_types, + exists=False, + source_columns=source_columns, + ) + + try: + # Build a JOIN-based DELETE instead of using CONCAT_WS. + # CONCAT_WS prevents MySQL/MariaDB from using indexes, causing full table scans. + on_condition = reduce( + lambda a, b: exp.And(this=a, expression=b), + [ + self._qualify_columns(k, target_alias).eq( + self._qualify_columns(k, temp_alias) + ) + for k in key + ], + ) + + target_table_aliased = exp.to_table(target_table).as_(target_alias, quoted=True) + temp_table_aliased = exp.to_table(temp_table).as_(temp_alias, quoted=True) + + join = exp.Join(this=temp_table_aliased, kind="INNER", on=on_condition) + target_table_aliased.append("joins", join) + + delete_stmt = exp.Delete( + tables=[exp.to_table(target_alias)], + this=target_table_aliased, + ) + self.execute(delete_stmt) + + insert_query = self._select_columns(target_columns_to_types).from_(temp_table) + if is_unique_key: + insert_query = insert_query.distinct(*key) + + insert_statement = exp.insert( + insert_query, + target_table, + columns=column_names, + ) + self.execute(insert_statement, track_rows_processed=True) + finally: + self.drop_table(temp_table) + + @staticmethod + def _qualify_columns(expr: exp.Expr, table_alias: str) -> exp.Expr: + """Qualify unqualified column references in an expression with a table alias.""" + expr = expr.copy() + for col in expr.find_all(exp.Column): + if not col.table: + col.set("table", exp.to_identifier(table_alias, quoted=True)) + return expr + def ping(self) -> None: self._connection_pool.get().ping(reconnect=False) diff --git a/tests/core/engine_adapter/test_mysql.py b/tests/core/engine_adapter/test_mysql.py index f9fe140892..58513319a0 100644 --- a/tests/core/engine_adapter/test_mysql.py +++ b/tests/core/engine_adapter/test_mysql.py @@ -1,5 +1,6 @@ # type: ignore import typing as t +from unittest.mock import call from pytest_mock.plugin import MockerFixture from sqlglot import exp, parse_one @@ -84,3 +85,78 @@ def test_create_table_like(make_mocked_engine_adapter: t.Callable): adapter.cursor.execute.assert_called_once_with( "CREATE TABLE IF NOT EXISTS `target_table` LIKE `source_table`" ) + + +def test_replace_by_key_composite_uses_join_delete( + make_mocked_engine_adapter: t.Callable, mocker: MockerFixture +): + """Composite key DELETE uses JOIN instead of CONCAT_WS to allow index usage.""" + adapter = make_mocked_engine_adapter(MySQLEngineAdapter) + temp_table_mock = mocker.patch( + "sqlmesh.core.engine_adapter.base.EngineAdapter._get_temp_table" + ) + temp_table_mock.return_value = exp.to_table("temporary") + + adapter.merge( + target_table="target", + source_table=t.cast(exp.Select, parse_one("SELECT id, ts, val FROM source")), + target_columns_to_types={ + "id": exp.DataType(this=exp.DataType.Type.INT), + "ts": exp.DataType(this=exp.DataType.Type.TIMESTAMP), + "val": exp.DataType(this=exp.DataType.Type.INT), + }, + unique_key=[parse_one("id"), parse_one("ts")], + ) + + sql_calls = to_sql_calls(adapter) + + # The DELETE should use a JOIN instead of CONCAT_WS + assert any("CONCAT_WS" in s for s in sql_calls) is False, ( + "DELETE should not use CONCAT_WS for composite keys" + ) + assert any("INNER JOIN" in s for s in sql_calls) is True, ( + "DELETE should use INNER JOIN for composite keys" + ) + + # Verify the full sequence of SQL calls + adapter.cursor.execute.assert_has_calls( + [ + call( + "CREATE TABLE `temporary` AS SELECT CAST(`id` AS SIGNED) AS `id`, CAST(`ts` AS DATETIME) AS `ts`, CAST(`val` AS SIGNED) AS `val` FROM (SELECT `id`, `ts`, `val` FROM `source`) AS `_subquery`" + ), + call( + "DELETE `_target` FROM `target` AS `_target` INNER JOIN `temporary` AS `_temp` ON `_target`.`id` = `_temp`.`id` AND `_target`.`ts` = `_temp`.`ts`" + ), + call( + "INSERT INTO `target` (`id`, `ts`, `val`) SELECT `id`, `ts`, `val` FROM (SELECT `id` AS `id`, `ts` AS `ts`, `val` AS `val`, ROW_NUMBER() OVER (PARTITION BY `id`, `ts` ORDER BY `id`, `ts`) AS _row_number FROM `temporary`) AS _t WHERE _row_number = 1" + ), + call("DROP TABLE IF EXISTS `temporary`"), + ] + ) + + +def test_replace_by_key_single_key_uses_in( + make_mocked_engine_adapter: t.Callable, mocker: MockerFixture +): + """Single key DELETE still uses the IN-based approach (indexes work fine for single column).""" + adapter = make_mocked_engine_adapter(MySQLEngineAdapter) + temp_table_mock = mocker.patch( + "sqlmesh.core.engine_adapter.base.EngineAdapter._get_temp_table" + ) + temp_table_mock.return_value = exp.to_table("temporary") + + adapter.merge( + target_table="target", + source_table=t.cast(exp.Select, parse_one("SELECT id, val FROM source")), + target_columns_to_types={ + "id": exp.DataType(this=exp.DataType.Type.INT), + "val": exp.DataType(this=exp.DataType.Type.INT), + }, + unique_key=[parse_one("id")], + ) + + sql_calls = to_sql_calls(adapter) + + # Single key should use IN-based approach, not JOIN + assert any("IN" in s and "DELETE" in s for s in sql_calls) is True + assert any("INNER JOIN" in s for s in sql_calls) is False