diff --git a/sync_diff_inspector/utils/utils.go b/sync_diff_inspector/utils/utils.go
index 382bfe02d..2470c72e5 100644
--- a/sync_diff_inspector/utils/utils.go
+++ b/sync_diff_inspector/utils/utils.go
@@ -32,6 +32,7 @@ import (
 	"github.com/pingcap/tidb-tools/sync_diff_inspector/chunk"
 	"github.com/pingcap/tidb/pkg/meta/model"
 	"github.com/pingcap/tidb/pkg/parser/mysql"
+	"github.com/pingcap/tidb/pkg/parser/types"
 	"go.uber.org/zap"
 )
 
@@ -135,6 +136,27 @@ func GetColumnsFromIndex(index *model.IndexInfo, tableInfo *model.TableInfo) []*
 	return indexColumns
 }
 
+// formatColumn returns the SQL expression used to select column `name` for comparison.
+//
+// Columns that are not FLOAT or DOUBLE are returned unchanged. FLOAT and DOUBLE columns
+// are rounded to 6 and 15 significant digits respectively; abs(name) is clamped from
+// below with greatest() so that log10() is never applied to 0 (which yields NULL).
+// NOTE(review): the reference that originally followed "See" is missing; restore it.
+func formatColumn(name string, fieldType *types.FieldType) string {
+	var k, epsilon string
+	switch fieldType.GetType() {
+	case mysql.TypeFloat:
+		k = "5" // k-floor(log10(|x|)) decimal places keeps 6 significant digits
+		epsilon = "1e-45"
+	case mysql.TypeDouble:
+		k = "14" // k-floor(log10(|x|)) decimal places keeps 15 significant digits
+		epsilon = "5e-324"
+	default:
+		return name
+	}
+	return fmt.Sprintf("round(%[1]s, %[2]s-floor(log10(greatest(abs(%[1]s), %[3]s))))", name, k, epsilon)
+}
+
 // GetTableRowsQueryFormat returns a rowsQuerySQL template for the specific table.
 //
 // e.g. SELECT /*!40001 SQL_NO_CACHE */ `a`, `b` FROM `schema`.`table` WHERE %s ORDER BY `a`.
@@ -148,14 +170,11 @@ func GetTableRowsQueryFormat(schema, table string, tableInfo *model.TableInfo, c
 		}
 
 		name := dbutil.ColumnName(col.Name.O)
-		// When col value is 0, the result is NULL.
-		// But we can use ISNULL to distinguish between null and 0.
-		if col.FieldType.GetType() == mysql.TypeFloat {
-			name = fmt.Sprintf("round(%s, 5-floor(log10(abs(%s)))) as %s", name, name, name)
-		} else if col.FieldType.GetType() == mysql.TypeDouble {
-			name = fmt.Sprintf("round(%s, 14-floor(log10(abs(%s)))) as %s", name, name, name)
+		expr := formatColumn(name, &col.FieldType)
+		if expr != name {
+			expr += " as " + name
 		}
-		columnNames = append(columnNames, name)
+		columnNames = append(columnNames, expr)
 	}
 	columns := strings.Join(columnNames, ", ")
 	if collation != "" {
@@ -785,13 +804,7 @@ func GetCountAndMd5Checksum(ctx context.Context, db *sql.DB, schemaName, tableNa
 			continue
 		}
 		name := dbutil.ColumnName(col.Name.O)
-		// When col value is 0, the result is NULL.
-		// But we can use ISNULL to distinguish between null and 0.
-		if col.FieldType.GetType() == mysql.TypeFloat {
-			name = fmt.Sprintf("round(%s, 5-floor(log10(abs(%s))))", name, name)
-		} else if col.FieldType.GetType() == mysql.TypeDouble {
-			name = fmt.Sprintf("round(%s, 14-floor(log10(abs(%s))))", name, name)
-		}
+		name = formatColumn(name, &col.FieldType)
 		columnNames = append(columnNames, name)
 		columnIsNull = append(columnIsNull, fmt.Sprintf("ISNULL(%s)", name))
 	}
diff --git a/sync_diff_inspector/utils/utils_test.go b/sync_diff_inspector/utils/utils_test.go
index 2131fa22c..a49fca11a 100644
--- a/sync_diff_inspector/utils/utils_test.go
+++ b/sync_diff_inspector/utils/utils_test.go
@@ -84,7 +84,7 @@ func TestBasicTableUtilOperation(t *testing.T) {
 	require.NoError(t, err)
 
 	query, orderKeyCols := GetTableRowsQueryFormat("test", "test", tableInfo, "123")
-	require.Equal(t, query, "SELECT /*!40001 SQL_NO_CACHE */ `a`, `b`, round(`c`, 5-floor(log10(abs(`c`)))) as `c`, `d` FROM `test`.`test` WHERE %s ORDER BY `a`,`b` COLLATE '123'")
+	require.Equal(t, query, "SELECT /*!40001 SQL_NO_CACHE */ `a`, `b`, round(`c`, 5-floor(log10(greatest(abs(`c`), 1e-45)))) as `c`, `d` FROM `test`.`test` WHERE %s ORDER BY `a`,`b` COLLATE '123'")
 	expectName := []string{"a", "b"}
 	for i, col := range orderKeyCols {
 		require.Equal(t, col.Name.O, expectName[i])