From 5f66f0896312698c982367f64a482a166ac81909 Mon Sep 17 00:00:00 2001 From: Douglas Danger Manley Date: Mon, 22 Dec 2025 21:03:48 -0500 Subject: [PATCH 1/3] Fix the information_schema queries. BigQuery requires that `information_schema` queries include the dataset name. This fixes the various Migrator methods to do this. --- driver/connection.go | 29 +++++++++-------- driver/driver.go | 2 +- driver/statement.go | 5 +-- driver/transaction.go | 2 +- migrator.go | 75 ++++++++++++++++++++++++++++++++++++++++--- 5 files changed, 90 insertions(+), 23 deletions(-) diff --git a/driver/connection.go b/driver/connection.go index 2d8303a..a27f4f0 100644 --- a/driver/connection.go +++ b/driver/connection.go @@ -1,13 +1,14 @@ package driver import ( - "cloud.google.com/go/bigquery" "context" "database/sql/driver" "fmt" + + "cloud.google.com/go/bigquery" ) -type bigQueryConnection struct { +type BigQueryConnection struct { ctx context.Context client *bigquery.Client config bigQueryConfig @@ -16,7 +17,7 @@ type bigQueryConnection struct { dataset *bigquery.Dataset } -func (connection *bigQueryConnection) GetDataset() *bigquery.Dataset { +func (connection *BigQueryConnection) GetDataset() *bigquery.Dataset { if connection.dataset != nil { return connection.dataset } @@ -24,11 +25,11 @@ func (connection *bigQueryConnection) GetDataset() *bigquery.Dataset { return connection.dataset } -func (connection *bigQueryConnection) GetContext() context.Context { +func (connection *BigQueryConnection) GetContext() context.Context { return connection.ctx } -func (connection *bigQueryConnection) Ping(ctx context.Context) error { +func (connection *BigQueryConnection) Ping(ctx context.Context) error { dataset := connection.GetDataset() if dataset == nil { @@ -43,12 +44,12 @@ func (connection *bigQueryConnection) Ping(ctx context.Context) error { return nil } -func (connection *bigQueryConnection) QueryContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Rows, error) { +func (connection *BigQueryConnection) QueryContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Rows, error) { var statement = &bigQueryStatement{connection, query} return statement.QueryContext(ctx, args) } -func (connection *bigQueryConnection) Query(query string, args []driver.Value) (driver.Rows, error) { +func (connection *BigQueryConnection) Query(query string, args []driver.Value) (driver.Rows, error) { statement, err := connection.Prepare(query) if err != nil { return nil, nil @@ -57,13 +58,13 @@ func (connection *bigQueryConnection) Query(query string, args []driver.Value) ( return statement.Query(args) } -func (connection *bigQueryConnection) Prepare(query string) (driver.Stmt, error) { +func (connection *BigQueryConnection) Prepare(query string) (driver.Stmt, error) { var statement = &bigQueryStatement{connection, query} return statement, nil } -func (connection *bigQueryConnection) Close() error { +func (connection *BigQueryConnection) Close() error { if connection.closed { return nil } @@ -74,27 +75,27 @@ func (connection *bigQueryConnection) Close() error { return connection.client.Close() } -func (connection *bigQueryConnection) Begin() (driver.Tx, error) { +func (connection *BigQueryConnection) Begin() (driver.Tx, error) { var transaction = &bigQueryTransaction{connection} return transaction, nil } -func (connection *bigQueryConnection) query(query string) (*bigquery.Query, error) { +func (connection *BigQueryConnection) query(query string) (*bigquery.Query, error) { return connection.client.Query(query), nil } -func (connection *bigQueryConnection) ExecContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Result, error) { +func (connection *BigQueryConnection) ExecContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Result, error) { var statement = &bigQueryStatement{connection, query} return statement.ExecContext(ctx, args) } -func (connection *bigQueryConnection) Exec(query string, args []driver.Value) (driver.Result, error) { +func (connection *BigQueryConnection) Exec(query string, args []driver.Value) (driver.Result, error) { var statement = &bigQueryStatement{connection, query} return statement.Exec(args) } -func (bigQueryConnection) CheckNamedValue(*driver.NamedValue) error { +func (BigQueryConnection) CheckNamedValue(*driver.NamedValue) error { // TODO: Revise in the future return nil } diff --git a/driver/driver.go b/driver/driver.go index fff8696..d11a1c8 100644 --- a/driver/driver.go +++ b/driver/driver.go @@ -57,7 +57,7 @@ func (b bigQueryDriver) Open(uri string) (driver.Conn, error) { return nil, err } - return &bigQueryConnection{ + return &BigQueryConnection{ ctx: ctx, client: client, config: *config, diff --git a/driver/statement.go b/driver/statement.go index dd28080..14ef201 100644 --- a/driver/statement.go +++ b/driver/statement.go @@ -1,16 +1,17 @@ package driver import ( - "cloud.google.com/go/bigquery" "context" "database/sql/driver" "errors" + + "cloud.google.com/go/bigquery" "github.com/sirupsen/logrus" "gorm.io/driver/bigquery/adaptor" ) type bigQueryStatement struct { - connection *bigQueryConnection + connection *BigQueryConnection query string } diff --git a/driver/transaction.go b/driver/transaction.go index 66b2539..e53148e 100644 --- a/driver/transaction.go +++ b/driver/transaction.go @@ -1,7 +1,7 @@ package driver type bigQueryTransaction struct { - connection *bigQueryConnection + connection *BigQueryConnection } func (transaction *bigQueryTransaction) Commit() error { diff --git a/migrator.go b/migrator.go index be7996d..5bb7a60 100644 --- a/migrator.go +++ b/migrator.go @@ -1,7 +1,11 @@ package bigquery import ( + "context" "errors" + "fmt" + + "gorm.io/driver/bigquery/driver" "gorm.io/gorm" "gorm.io/gorm/clause" "gorm.io/gorm/migrator" @@ -13,8 +17,11 @@ type Migrator struct { } func (m Migrator) CurrentDatabase() (name string) { - m.DB.Raw("SELECT CURRENT_DATABASE()").Row().Scan(&name) - return + datasetID, err := m.getDatasetID() + if err != nil { + return "" + } + return datasetID } func (m Migrator) BuildIndexOptions(opts []schema.IndexOption, stmt *gorm.Statement) (results []interface{}) { @@ -40,7 +47,15 @@ func (m Migrator) DropIndex(value interface{}, name string) error { func (m Migrator) HasTable(value interface{}) bool { var count int64 m.RunWithValue(value, func(stmt *gorm.Statement) error { - return m.DB.Raw("SELECT count(*) FROM `INFORMATION_SCHEMA.TABLES` WHERE table_name = ?", stmt.Table).Row().Scan(&count) + // According to the BigQuery documentation, an INFORMATION_SCHEMA view must be qualified with a dataset or region. + // See: https://docs.cloud.google.com/bigquery/docs/information-schema-intro + // + // We are going to attempt to get the dataset ID from the connection and use it to qualify the INFORMATION_SCHEMA view. + datasetID, err := m.getDatasetID() + if err != nil { + return err + } + return m.DB.Raw("SELECT count(*) FROM `"+datasetID+".INFORMATION_SCHEMA.TABLES` WHERE table_name = ?", stmt.Table).Row().Scan(&count) }) return count > 0 @@ -67,8 +82,17 @@ func (m Migrator) HasColumn(value interface{}, field string) bool { name = field.DBName } + // According to the BigQuery documentation, an INFORMATION_SCHEMA view must be qualified with a dataset or region. + // See: https://docs.cloud.google.com/bigquery/docs/information-schema-intro + // + // We are going to attempt to get the dataset ID from the connection and use it to qualify the INFORMATION_SCHEMA view. + datasetID, err := m.getDatasetID() + if err != nil { + return err + } + return m.DB.Raw( - "SELECT count(*) FROM INFORMATION_SCHEMA.columns WHERE table_schema = CURRENT_SCHEMA() AND table_name = ? AND column_name = ?", + "SELECT count(*) FROM `"+datasetID+".INFORMATION_SCHEMA.columns` WHERE table_schema = CURRENT_SCHEMA() AND table_name = ? AND column_name = ?", stmt.Table, name, ).Row().Scan(&count) }) @@ -79,11 +103,52 @@ func (m Migrator) HasColumn(value interface{}, field string) bool { func (m Migrator) HasConstraint(value interface{}, name string) bool { var count int64 m.RunWithValue(value, func(stmt *gorm.Statement) error { + // According to the BigQuery documentation, an INFORMATION_SCHEMA view must be qualified with a dataset or region. + // See: https://docs.cloud.google.com/bigquery/docs/information-schema-intro + // + // We are going to attempt to get the dataset ID from the connection and use it to qualify the INFORMATION_SCHEMA view. + datasetID, err := m.getDatasetID() + if err != nil { + return err + } + return m.DB.Raw( - "SELECT count(*) FROM INFORMATION_SCHEMA.table_constraints WHERE table_schema = CURRENT_SCHEMA() AND table_name = ? AND constraint_name = ?", + "SELECT count(*) FROM `"+datasetID+".INFORMATION_SCHEMA.table_constraints` WHERE table_schema = CURRENT_SCHEMA() AND table_name = ? AND constraint_name = ?", stmt.Table, name, ).Row().Scan(&count) }) return count > 0 } + +// getDatasetID is a helper function to get the dataset ID from the connection. +func (m Migrator) getDatasetID() (string, error) { + sqlDB, err := m.DB.DB() + if err != nil { + return "", fmt.Errorf("could not get underlying database: %w", err) + } + ctx := context.Background() + conn, err := sqlDB.Conn(ctx) + if err != nil { + return "", fmt.Errorf("could not get connection: %w", err) + } + + datasetID := "" + err = conn.Raw(func(rawConnection any) error { + bigQueryConnection, ok := rawConnection.(*driver.BigQueryConnection) + if !ok { + return errors.New("connection is not a *driver.BigQueryConnection") + } + dataset := bigQueryConnection.GetDataset() + if dataset == nil { + return errors.New("dataset is nil") + } + datasetID = dataset.DatasetID + return nil + }) + if err != nil { + return "", fmt.Errorf("could not get dataset ID: %w", err) + } + + return datasetID, nil +} From de04626904f872931ef264b739ee84b7bdfbac0f Mon Sep 17 00:00:00 2001 From: Douglas Danger Manley Date: Tue, 23 Dec 2025 10:00:23 -0500 Subject: [PATCH 2/3] Also support the comment --- migrator.go | 38 ++++++++++++++++++++++++++++++++++++++ 1 file changed, 38 insertions(+) diff --git a/migrator.go b/migrator.go index 5bb7a60..1918691 100644 --- a/migrator.go +++ b/migrator.go @@ -4,10 +4,13 @@ import ( "context" "errors" "fmt" + "slices" + "strings" "gorm.io/driver/bigquery/driver" "gorm.io/gorm" "gorm.io/gorm/clause" + "gorm.io/gorm/logger" "gorm.io/gorm/migrator" "gorm.io/gorm/schema" ) @@ -121,6 +124,41 @@ func (m Migrator) HasConstraint(value interface{}, name string) bool { return count > 0 } +// FullDataTypeOf returns field's db full data type +func (m Migrator) FullDataTypeOf(field *schema.Field) (expr clause.Expr) { + expr.SQL = m.DataTypeOf(field) + + if field.NotNull { + expr.SQL += " NOT NULL" + } + + if field.HasDefaultValue && (field.DefaultValueInterface != nil || field.DefaultValue != "") { + if field.DefaultValueInterface != nil { + defaultStmt := &gorm.Statement{Vars: []interface{}{field.DefaultValueInterface}} + m.Dialector.BindVarTo(defaultStmt, defaultStmt, field.DefaultValueInterface) + expr.SQL += " DEFAULT " + m.Dialector.Explain(defaultStmt.SQL.String(), field.DefaultValueInterface) + } else if field.DefaultValue != "(-)" { + expr.SQL += " DEFAULT " + field.DefaultValue + } + } + + options := map[string]string{} + if field.Comment != "" { + options["description"] = field.Comment + } + + if len(options) > 0 { + optionParts := []string{} + for key, value := range options { + optionParts = append(optionParts, fmt.Sprintf("%s = %s", key, logger.ExplainSQL("?", nil, `'`, value))) + } + slices.Sort(optionParts) + expr.SQL += " OPTIONS (" + strings.Join(optionParts, " ") + ")" + } + + return +} + // getDatasetID is a helper function to get the dataset ID from the connection. func (m Migrator) getDatasetID() (string, error) { sqlDB, err := m.DB.DB() From 21f0025758a61ec6445601c34868ca5584bb3f95 Mon Sep 17 00:00:00 2001 From: Douglas Danger Manley Date: Tue, 23 Dec 2025 12:40:46 -0500 Subject: [PATCH 3/3] Add comma --- migrator.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/migrator.go b/migrator.go index 1918691..dd6e844 100644 --- a/migrator.go +++ b/migrator.go @@ -153,7 +153,7 @@ func (m Migrator) FullDataTypeOf(field *schema.Field) (expr clause.Expr) { optionParts = append(optionParts, fmt.Sprintf("%s = %s", key, logger.ExplainSQL("?", nil, `'`, value))) } slices.Sort(optionParts) - expr.SQL += " OPTIONS (" + strings.Join(optionParts, " ") + ")" + expr.SQL += " OPTIONS (" + strings.Join(optionParts, ", ") + ")" } return