diff --git a/cmd/schema-test/migrations.go b/cmd/schema-test/migrations.go index c5c66d8..882b343 100644 --- a/cmd/schema-test/migrations.go +++ b/cmd/schema-test/migrations.go @@ -68,4 +68,13 @@ var allMigrations = []migrations.NamedMigration{ `DROP TYPE type1_old`, }), }, + { + Name: "Create a view referencing a new enum value", + Migration: migrations.StaticMigration([]string{ + `CREATE VIEW v AS SELECT * FROM table3 WHERE v = 'type1val3'`, + }), + Reverse: migrations.StaticMigration([]string{ + `DROP VIEW v`, + }), + }, } diff --git a/migrations/migration.go b/migrations/migration.go index bfcb116..b53fbd1 100644 --- a/migrations/migration.go +++ b/migrations/migration.go @@ -105,33 +105,10 @@ func verifyMigrations(tx *sqlx.Tx, migrations []NamedMigration) (firstUnappliedM return firstUnappliedMigrationIndex, nil } -func doMigrations(tx *sqlx.Tx, migrations []NamedMigration, startIndex int) error { - for index := startIndex; index < len(migrations); index++ { - migration := migrations[index] - log.Printf("Performing migration %d (%q)", index, migration.Name) - err := migrations[index].Migration.DoMigration(tx) - if err != nil { - return fmt.Errorf("Error performing migration %d (%q): %w", index, migration.Name, err) - } - _, err = tx.Exec(`INSERT INTO migration ("index", name) VALUES ($1, $2)`, index, migration.Name) - if err != nil { - return fmt.Errorf("Error recording migration %d (%q): %w", index, migration.Name, err) - } - } - - return nil -} - -// Rollback runs the Reverse migrations for all the input migrations with index >= rollBackThroughIndex. -// The input migrations must include all migrations, not just the ones to roll back. -func Rollback(db *sqlx.DB, migrations []NamedMigration, rollBackThroughIndex int) error { - err := ensureMigrationsTableExists(db) - if err != nil { - return err - } +func migrateOne(db *sqlx.DB, migrations []NamedMigration) (bool, error) { tx, err := db.Beginx() if err != nil { - return fmt.Errorf("Error starting migrations transaction: %w", err) + return false, fmt.Errorf("Error starting migrations transaction: %w", err) } committed := false defer func() { @@ -145,57 +122,40 @@ func Rollback(db *sqlx.DB, migrations []NamedMigration, rollBackThroughIndex int _, err = tx.Exec("LOCK TABLE migration") if err != nil { - return fmt.Errorf("Error locking migration table: %w", err) + return false, fmt.Errorf("Error locking migration table: %w", err) } firstUnappliedIndex, err := verifyMigrations(tx, migrations) if err != nil { - return err + return false, err } - - if rollBackThroughIndex < 0 { - return fmt.Errorf("Invalid target index %d", rollBackThroughIndex) - } - if rollBackThroughIndex >= firstUnappliedIndex { - return fmt.Errorf("Migration %d has not been applied yet", rollBackThroughIndex) + if firstUnappliedIndex >= len(migrations) { + return false, nil } - for index := firstUnappliedIndex - 1; index >= rollBackThroughIndex; index-- { - migration := migrations[index] - if migration.Reverse == nil { - return fmt.Errorf("No Reverse for migration %d (%q)", index, migration.Name) - } - log.Printf("Reversing migration %d (%q)", index, migration.Name) - err := migrations[index].Reverse.DoMigration(tx) - if err != nil { - return fmt.Errorf("Error reversing migration %d (%q): %w", index, migration.Name, err) - } - _, err = tx.Exec(`DELETE FROM migration WHERE "index"=$1`, index) - if err != nil { - return fmt.Errorf("Error deleting migration row %d (%q): %w", index, migration.Name, err) - } + migration := migrations[firstUnappliedIndex] + log.Printf("Performing migration %d (%q)", firstUnappliedIndex, migration.Name) + err = migrations[firstUnappliedIndex].Migration.DoMigration(tx) + if err != nil { + return false, fmt.Errorf("Error performing migration %d (%q): %w", firstUnappliedIndex, migration.Name, err) + } + _, err = tx.Exec(`INSERT INTO migration ("index", name) VALUES ($1, $2)`, firstUnappliedIndex, migration.Name) + if err != nil { + return false, fmt.Errorf("Error recording migration %d (%q): %w", firstUnappliedIndex, migration.Name, err) } + committed = true err = tx.Commit() if err != nil { - return fmt.Errorf("Error committing migrations: %w", err) + return false, fmt.Errorf("Error committing migrations: %w", err) } - committed = true - return nil + return true, nil } -// Migrate does the following: -// 1. Verifies that the `migration` table exists, and creates it if it does not. -// 2. Verifies that the existing migrations recorded in the database match (by name and order) the migrations given as the argument. -// 3. Performs any migrations that are not yet recorded in the database. -func Migrate(db *sqlx.DB, migrations []NamedMigration) error { - err := ensureMigrationsTableExists(db) - if err != nil { - return err - } +func rollbackOne(db *sqlx.DB, migrations []NamedMigration, rollBackThroughIndex int) (rolledBackIndex int, err error) { tx, err := db.Beginx() if err != nil { - return fmt.Errorf("Error starting migrations transaction: %w", err) + return -1, fmt.Errorf("Error starting migrations transaction: %w", err) } committed := false defer func() { @@ -209,23 +169,78 @@ func Migrate(db *sqlx.DB, migrations []NamedMigration) error { _, err = tx.Exec("LOCK TABLE migration") if err != nil { - return fmt.Errorf("Error locking migration table: %w", err) + return -1, fmt.Errorf("Error locking migration table: %w", err) } firstUnappliedIndex, err := verifyMigrations(tx, migrations) if err != nil { - return err + return -1, err } - err = doMigrations(tx, migrations, firstUnappliedIndex) + if rollBackThroughIndex < 0 { + return -1, fmt.Errorf("Invalid target index %d", rollBackThroughIndex) + } + if rollBackThroughIndex >= firstUnappliedIndex { + return -1, fmt.Errorf("Migration %d has not been applied yet", rollBackThroughIndex) + } + + index := firstUnappliedIndex - 1 + migration := migrations[index] + if migration.Reverse == nil { + return -1, fmt.Errorf("No Reverse for migration %d (%q)", index, migration.Name) + } + log.Printf("Reversing migration %d (%q)", index, migration.Name) + err = migrations[index].Reverse.DoMigration(tx) if err != nil { - return err + return -1, fmt.Errorf("Error reversing migration %d (%q): %w", index, migration.Name, err) + } + _, err = tx.Exec(`DELETE FROM migration WHERE "index"=$1`, index) + if err != nil { + return -1, fmt.Errorf("Error deleting migration row %d (%q): %w", index, migration.Name, err) } + committed = true err = tx.Commit() if err != nil { - return fmt.Errorf("Error committing migrations: %w", err) + return -1, fmt.Errorf("Error committing migrations: %w", err) + } + return index, nil +} + +// Rollback runs the Reverse migrations for all the input migrations with index >= rollBackThroughIndex. +// The input migrations must include all migrations, not just the ones to roll back. +func Rollback(db *sqlx.DB, migrations []NamedMigration, rollBackThroughIndex int) error { + err := ensureMigrationsTableExists(db) + if err != nil { + return err + } + for { + rolledBackIndex, err := rollbackOne(db, migrations, rollBackThroughIndex) + if err != nil { + return err + } + if rolledBackIndex == rollBackThroughIndex { + return nil + } + } +} + +// Migrate does the following: +// 1. Verifies that the `migration` table exists, and creates it if it does not. +// 2. Verifies that the existing migrations recorded in the database match (by name and order) the migrations given as the argument. +// 3. Performs any migrations that are not yet recorded in the database. +func Migrate(db *sqlx.DB, migrations []NamedMigration) error { + err := ensureMigrationsTableExists(db) + if err != nil { + return err + } + for { + migrated, err := migrateOne(db, migrations) + if err != nil { + return err + } + if !migrated { + return nil + } } - committed = true - return nil } diff --git a/migrations/verify.go b/migrations/verify.go index abbd94c..1ca1849 100644 --- a/migrations/verify.go +++ b/migrations/verify.go @@ -65,15 +65,61 @@ func verifyNoTables(db *sqlx.DB) error { return errors.New("Existing tables found. You must run SchemaTest on an empty database.") } +func migrateAndRollback(emptyDBConfig *PostgresConfig, db *sqlx.DB, allMigrations []NamedMigration, migrateToIndex, rollbackThroughIndex int, repeatForward bool) error { + beforeMigrate, err := dump(emptyDBConfig) + if err != nil { + return fmt.Errorf("Error calling pg_dump: %s", err) + } + err = Migrate(db, allMigrations[:migrateToIndex+1]) + if err != nil { + return fmt.Errorf("Migrate to %q failed: %s", allMigrations[migrateToIndex].Name, err) + } + afterMigrate, err := dump(emptyDBConfig) + if err != nil { + return fmt.Errorf("Error calling pg_dump: %s", err) + } + err = Rollback(db, allMigrations, rollbackThroughIndex) + if err != nil { + return fmt.Errorf("Rollback through %q failed: %s", allMigrations[rollbackThroughIndex].Name, err) + } + afterRollback, err := dump(emptyDBConfig) + if err != nil { + return fmt.Errorf("Error calling pg_dump: %s", err) + } + if string(beforeMigrate) != string(afterRollback) { + fmt.Printf("%s\n", cmp.Diff(string(beforeMigrate), string(afterRollback))) + return fmt.Errorf("Dump after rollback through %q did not match the dump before the migration", allMigrations[rollbackThroughIndex].Name) + } + if repeatForward { + err = Migrate(db, allMigrations[:migrateToIndex+1]) + if err != nil { + return fmt.Errorf("Migration to %q failed: %s", allMigrations[migrateToIndex].Name, err) + } + afterMigrateAgain, err := dump(emptyDBConfig) + if err != nil { + return fmt.Errorf("Error calling pg_dump: %s", err) + } + if string(afterMigrate) != string(afterMigrateAgain) { + fmt.Printf("%s\n", cmp.Diff(string(afterMigrate), string(afterMigrateAgain))) + return fmt.Errorf("Dump after re-migration of %q did not match dump after first migration", allMigrations[migrateToIndex].Name) + } + } + return err +} + // Schema test expects a new *empty* postgres database. -// It will, for each migration: -// 1. Apply the migration -// 2. Reverse the migration -// 3. Apply the migration again +// It will: +// 1. Apply all migrations +// 2. Reverse all migrations +// 3. For each migration: +// a. Apply the migration +// b. Reverse the migration +// c. Apply the migration again +// // Before and after each step it will use pg_dump to dump the database schema. // It will verify that: -// A. The schema is the same after step 2 as before step 1. -// B. The schema is the same after step 3 as after step 1. +// A. The schema is the same after reversing as before applying. +// B. (If re-applying) The schema is the same after applying as after re-applying. // // You must have `pg_dump` in your `PATH` to run this. func SchemaTest(emptyDBConfig *PostgresConfig, allMigrations []NamedMigration) error { @@ -103,42 +149,14 @@ func SchemaTest(emptyDBConfig *PostgresConfig, allMigrations []NamedMigration) e if err != nil { return fmt.Errorf("Setting up migrations table failed: %s", err) } - for idx, migration := range allMigrations { - beforeMigrate, err := dump(emptyDBConfig) - if err != nil { - return fmt.Errorf("Error calling pg_dump: %s", err) - } - err = Migrate(db, allMigrations[:idx+1]) - if err != nil { - return fmt.Errorf("Migration %q failed: %s", migration.Name, err) - } - afterMigrate, err := dump(emptyDBConfig) - if err != nil { - return fmt.Errorf("Error calling pg_dump: %s", err) - } - err = Rollback(db, allMigrations, idx) - if err != nil { - return fmt.Errorf("Rollback to %q failed: %s", migration.Name, err) - } - afterRollback, err := dump(emptyDBConfig) - if err != nil { - return fmt.Errorf("Error calling pg_dump: %s", err) - } - if string(beforeMigrate) != string(afterRollback) { - fmt.Printf("%s\n", cmp.Diff(string(beforeMigrate), string(afterRollback))) - return fmt.Errorf("Dump after rollback of %q did not match the dump before the migration", migration.Name) - } - err = Migrate(db, allMigrations[:idx+1]) - if err != nil { - return fmt.Errorf("Migration %q failed: %s", migration.Name, err) - } - afterMigrateAgain, err := dump(emptyDBConfig) + err = migrateAndRollback(emptyDBConfig, db, allMigrations, len(allMigrations)-1, 0, false) + if err != nil { + return err + } + for idx := range allMigrations { + err := migrateAndRollback(emptyDBConfig, db, allMigrations, idx, idx, true) if err != nil { - return fmt.Errorf("Error calling pg_dump: %s", err) - } - if string(afterMigrate) != string(afterMigrateAgain) { - fmt.Printf("%s\n", cmp.Diff(string(afterMigrate), string(afterMigrateAgain))) - return fmt.Errorf("Dump after re-migration of %q did not match dump after first migration", migration.Name) + return err } } return nil