Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
82 changes: 81 additions & 1 deletion sqlmesh/core/engine_adapter/mysql.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import logging
import typing as t
from functools import reduce

from sqlglot import exp, parse_one

Expand All @@ -21,7 +22,7 @@
)

if t.TYPE_CHECKING:
from sqlmesh.core._typing import SchemaName, TableName
from sqlmesh.core._typing import QueryOrDF, SchemaName, TableName

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -186,5 +187,84 @@ def _create_table_like(
)
)

def _replace_by_key(
    self,
    target_table: TableName,
    source_table: QueryOrDF,
    target_columns_to_types: t.Optional[t.Dict[str, exp.DataType]],
    key: t.Sequence[exp.Expr],
    is_unique_key: bool,
    source_columns: t.Optional[t.List[str]] = None,
) -> None:
    """Replace target rows that share a key with the source, using a JOIN-based DELETE.

    Keys with at most one expression are delegated unchanged to the parent
    implementation. For composite keys the source is materialized into a
    temporary table, matching target rows are removed with a multi-table
    ``DELETE ... INNER JOIN``, and the temporary table's rows are then inserted
    into the target. The temporary table is always dropped afterwards.

    Args:
        target_table: Table whose matching rows are replaced.
        source_table: Query or DataFrame providing the replacement rows.
        target_columns_to_types: Column name to type mapping for the target;
            looked up via ``self.columns(target_table)`` when ``None``.
        key: Expressions that together identify a row.
        is_unique_key: When true, the inserted rows are de-duplicated on ``key``.
        source_columns: Forwarded as-is to ``ctas`` (and to the parent
            implementation on the single-key path).
    """
    # Single-expression (or empty) keys keep the inherited behaviour.
    if len(key) <= 1:
        return super()._replace_by_key(
            target_table, source_table, target_columns_to_types, key, is_unique_key, source_columns
        )

    columns_to_types = (
        self.columns(target_table) if target_columns_to_types is None else target_columns_to_types
    )

    staging_table = self._get_temp_table(target_table)
    insert_columns = list(columns_to_types) if columns_to_types else []

    target_alias, temp_alias = "_target", "_temp"

    with self.transaction():
        # Materialize the source once so both the DELETE and the INSERT read the same rows.
        self.ctas(
            staging_table,
            source_table,
            target_columns_to_types=columns_to_types,
            exists=False,
            source_columns=source_columns,
        )

        try:
            # A JOIN is used instead of comparing CONCAT_WS(...) values because
            # CONCAT_WS prevents MySQL/MariaDB from using indexes, causing full
            # table scans.
            # NOTE(review): plain `=` does not match NULL key values, so target
            # rows with NULL in a key column are presumably left in place here -
            # confirm that is acceptable for callers.
            key_matches = [
                self._qualify_columns(part, target_alias).eq(
                    self._qualify_columns(part, temp_alias)
                )
                for part in key
            ]
            join_condition = reduce(
                lambda left, right: exp.And(this=left, expression=right),
                key_matches,
            )

            aliased_target = exp.to_table(target_table).as_(target_alias, quoted=True)
            aliased_staging = exp.to_table(staging_table).as_(temp_alias, quoted=True)
            aliased_target.append(
                "joins",
                exp.Join(this=aliased_staging, kind="INNER", on=join_condition),
            )

            # `tables` names the alias whose rows are deleted; `this` is the joined FROM clause.
            self.execute(
                exp.Delete(
                    tables=[exp.to_table(target_alias)],
                    this=aliased_target,
                )
            )

            select_from_staging = self._select_columns(columns_to_types).from_(staging_table)
            if is_unique_key:
                select_from_staging = select_from_staging.distinct(*key)

            self.execute(
                exp.insert(
                    select_from_staging,
                    target_table,
                    columns=insert_columns,
                ),
                track_rows_processed=True,
            )
        finally:
            # Runs on success and on failure so the staging table is dropped either way.
            self.drop_table(staging_table)

@staticmethod
def _qualify_columns(expr: exp.Expr, table_alias: str) -> exp.Expr:
    """Return a copy of ``expr`` whose table-less columns are bound to ``table_alias``.

    Columns that already carry a table qualifier are left as they are, and the
    expression passed in is not modified.
    """
    qualified = expr.copy()
    bare_columns = (column for column in qualified.find_all(exp.Column) if not column.table)
    for column in bare_columns:
        # A fresh quoted identifier is attached to every column that lacks one.
        column.set("table", exp.to_identifier(table_alias, quoted=True))
    return qualified

def ping(self) -> None:
    """Ping the connection obtained from the pool, passing ``reconnect=False``."""
    connection = self._connection_pool.get()
    connection.ping(reconnect=False)
76 changes: 76 additions & 0 deletions tests/core/engine_adapter/test_mysql.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
# type: ignore
import typing as t
from unittest.mock import call

from pytest_mock.plugin import MockerFixture
from sqlglot import exp, parse_one
Expand Down Expand Up @@ -84,3 +85,78 @@ def test_create_table_like(make_mocked_engine_adapter: t.Callable):
adapter.cursor.execute.assert_called_once_with(
"CREATE TABLE IF NOT EXISTS `target_table` LIKE `source_table`"
)


def test_replace_by_key_composite_uses_join_delete(
    make_mocked_engine_adapter: t.Callable, mocker: MockerFixture
):
    """A composite unique key produces a JOIN-based DELETE rather than a CONCAT_WS comparison."""
    adapter = make_mocked_engine_adapter(MySQLEngineAdapter)
    # Pin the temp table name so the generated SQL is deterministic.
    mocker.patch(
        "sqlmesh.core.engine_adapter.base.EngineAdapter._get_temp_table",
        return_value=exp.to_table("temporary"),
    )

    data_type = exp.DataType.Type
    columns_to_types = {
        "id": exp.DataType(this=data_type.INT),
        "ts": exp.DataType(this=data_type.TIMESTAMP),
        "val": exp.DataType(this=data_type.INT),
    }
    source_query = t.cast(exp.Select, parse_one("SELECT id, ts, val FROM source"))

    adapter.merge(
        target_table="target",
        source_table=source_query,
        target_columns_to_types=columns_to_types,
        unique_key=[parse_one(name) for name in ("id", "ts")],
    )

    executed_sql = to_sql_calls(adapter)

    # The DELETE must join against the temp table and must not fall back to CONCAT_WS.
    assert not any("CONCAT_WS" in sql for sql in executed_sql), (
        "DELETE should not use CONCAT_WS for composite keys"
    )
    assert any("INNER JOIN" in sql for sql in executed_sql), (
        "DELETE should use INNER JOIN for composite keys"
    )

    # Expected statements, in execution order: stage, delete, insert, clean up.
    expected_statements = [
        "CREATE TABLE `temporary` AS SELECT CAST(`id` AS SIGNED) AS `id`, CAST(`ts` AS DATETIME) AS `ts`, CAST(`val` AS SIGNED) AS `val` FROM (SELECT `id`, `ts`, `val` FROM `source`) AS `_subquery`",
        "DELETE `_target` FROM `target` AS `_target` INNER JOIN `temporary` AS `_temp` ON `_target`.`id` = `_temp`.`id` AND `_target`.`ts` = `_temp`.`ts`",
        "INSERT INTO `target` (`id`, `ts`, `val`) SELECT `id`, `ts`, `val` FROM (SELECT `id` AS `id`, `ts` AS `ts`, `val` AS `val`, ROW_NUMBER() OVER (PARTITION BY `id`, `ts` ORDER BY `id`, `ts`) AS _row_number FROM `temporary`) AS _t WHERE _row_number = 1",
        "DROP TABLE IF EXISTS `temporary`",
    ]
    adapter.cursor.execute.assert_has_calls([call(sql) for sql in expected_statements])


def test_replace_by_key_single_key_uses_in(
    make_mocked_engine_adapter: t.Callable, mocker: MockerFixture
):
    """A single-column unique key keeps the IN-based DELETE and emits no JOIN."""
    adapter = make_mocked_engine_adapter(MySQLEngineAdapter)
    # Pin the temp table name so the generated SQL is deterministic.
    mocker.patch(
        "sqlmesh.core.engine_adapter.base.EngineAdapter._get_temp_table",
        return_value=exp.to_table("temporary"),
    )

    int_type = exp.DataType.Type.INT
    adapter.merge(
        target_table="target",
        source_table=t.cast(exp.Select, parse_one("SELECT id, val FROM source")),
        target_columns_to_types={name: exp.DataType(this=int_type) for name in ("id", "val")},
        unique_key=[parse_one("id")],
    )

    executed_sql = to_sql_calls(adapter)
    delete_statements = [sql for sql in executed_sql if "DELETE" in sql]

    # NOTE(review): the bare "IN" substring would also match "INNER"; the
    # second assertion is what actually rules out a JOIN-based DELETE.
    assert any("IN" in sql for sql in delete_statements)
    assert not any("INNER JOIN" in sql for sql in executed_sql)