Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 9 additions & 0 deletions cmd/schema-test/migrations.go
Original file line number Diff line number Diff line change
Expand Up @@ -68,4 +68,13 @@ var allMigrations = []migrations.NamedMigration{
`DROP TYPE type1_old`,
}),
},
{
Name: "Create a view referencing a new enum value",
Migration: migrations.StaticMigration([]string{
`CREATE VIEW v AS SELECT * FROM table3 WHERE v = 'type1val3'`,
}),
Reverse: migrations.StaticMigration([]string{
`DROP VIEW v`,
}),
},
}
149 changes: 82 additions & 67 deletions migrations/migration.go
Original file line number Diff line number Diff line change
Expand Up @@ -105,33 +105,10 @@ func verifyMigrations(tx *sqlx.Tx, migrations []NamedMigration) (firstUnappliedM
return firstUnappliedMigrationIndex, nil
}

func doMigrations(tx *sqlx.Tx, migrations []NamedMigration, startIndex int) error {
for index := startIndex; index < len(migrations); index++ {
migration := migrations[index]
log.Printf("Performing migration %d (%q)", index, migration.Name)
err := migrations[index].Migration.DoMigration(tx)
if err != nil {
return fmt.Errorf("Error performing migration %d (%q): %w", index, migration.Name, err)
}
_, err = tx.Exec(`INSERT INTO migration ("index", name) VALUES ($1, $2)`, index, migration.Name)
if err != nil {
return fmt.Errorf("Error recording migration %d (%q): %w", index, migration.Name, err)
}
}

return nil
}

// Rollback runs the Reverse migrations for all the input migrations with index >= rollBackThroughIndex.
// The input migrations must include all migrations, not just the ones to roll back.
func Rollback(db *sqlx.DB, migrations []NamedMigration, rollBackThroughIndex int) error {
err := ensureMigrationsTableExists(db)
if err != nil {
return err
}
func migrateOne(db *sqlx.DB, migrations []NamedMigration) (bool, error) {
tx, err := db.Beginx()
if err != nil {
return fmt.Errorf("Error starting migrations transaction: %w", err)
return false, fmt.Errorf("Error starting migrations transaction: %w", err)
}
committed := false
defer func() {
Expand All @@ -145,57 +122,40 @@ func Rollback(db *sqlx.DB, migrations []NamedMigration, rollBackThroughIndex int

_, err = tx.Exec("LOCK TABLE migration")
if err != nil {
return fmt.Errorf("Error locking migration table: %w", err)
return false, fmt.Errorf("Error locking migration table: %w", err)
}

firstUnappliedIndex, err := verifyMigrations(tx, migrations)
if err != nil {
return err
return false, err
}

if rollBackThroughIndex < 0 {
return fmt.Errorf("Invalid target index %d", rollBackThroughIndex)
}
if rollBackThroughIndex >= firstUnappliedIndex {
return fmt.Errorf("Migration %d has not been applied yet", rollBackThroughIndex)
if firstUnappliedIndex >= len(migrations) {
return false, nil
}

for index := firstUnappliedIndex - 1; index >= rollBackThroughIndex; index-- {
migration := migrations[index]
if migration.Reverse == nil {
return fmt.Errorf("No Reverse for migration %d (%q)", index, migration.Name)
}
log.Printf("Reversing migration %d (%q)", index, migration.Name)
err := migrations[index].Reverse.DoMigration(tx)
if err != nil {
return fmt.Errorf("Error reversing migration %d (%q): %w", index, migration.Name, err)
}
_, err = tx.Exec(`DELETE FROM migration WHERE "index"=$1`, index)
if err != nil {
return fmt.Errorf("Error deleting migration row %d (%q): %w", index, migration.Name, err)
}
migration := migrations[firstUnappliedIndex]
log.Printf("Performing migration %d (%q)", firstUnappliedIndex, migration.Name)
err = migrations[firstUnappliedIndex].Migration.DoMigration(tx)
if err != nil {
return false, fmt.Errorf("Error performing migration %d (%q): %w", firstUnappliedIndex, migration.Name, err)
}
_, err = tx.Exec(`INSERT INTO migration ("index", name) VALUES ($1, $2)`, firstUnappliedIndex, migration.Name)
if err != nil {
return false, fmt.Errorf("Error recording migration %d (%q): %w", firstUnappliedIndex, migration.Name, err)
}

committed = true
err = tx.Commit()
if err != nil {
return fmt.Errorf("Error committing migrations: %w", err)
return false, fmt.Errorf("Error committing migrations: %w", err)
}
committed = true
return nil
return true, nil
}

// Migrate does the following:
// 1. Verifies that the `migration` table exists, and creates it if it does not.
// 2. Verifies that the existing migrations recorded in the database match (by name and order) the migrations given as the argument.
// 3. Performs any migrations that are not yet recorded in the database.
func Migrate(db *sqlx.DB, migrations []NamedMigration) error {
err := ensureMigrationsTableExists(db)
if err != nil {
return err
}
func rollbackOne(db *sqlx.DB, migrations []NamedMigration, rollBackThroughIndex int) (rolledBackIndex int, err error) {
tx, err := db.Beginx()
if err != nil {
return fmt.Errorf("Error starting migrations transaction: %w", err)
return -1, fmt.Errorf("Error starting migrations transaction: %w", err)
}
committed := false
defer func() {
Expand All @@ -209,23 +169,78 @@ func Migrate(db *sqlx.DB, migrations []NamedMigration) error {

_, err = tx.Exec("LOCK TABLE migration")
if err != nil {
return fmt.Errorf("Error locking migration table: %w", err)
return -1, fmt.Errorf("Error locking migration table: %w", err)
}

firstUnappliedIndex, err := verifyMigrations(tx, migrations)
if err != nil {
return err
return -1, err
}

err = doMigrations(tx, migrations, firstUnappliedIndex)
if rollBackThroughIndex < 0 {
return -1, fmt.Errorf("Invalid target index %d", rollBackThroughIndex)
}
if rollBackThroughIndex >= firstUnappliedIndex {
return -1, fmt.Errorf("Migration %d has not been applied yet", rollBackThroughIndex)
}

index := firstUnappliedIndex - 1
migration := migrations[index]
if migration.Reverse == nil {
return -1, fmt.Errorf("No Reverse for migration %d (%q)", index, migration.Name)
}
log.Printf("Reversing migration %d (%q)", index, migration.Name)
err = migrations[index].Reverse.DoMigration(tx)
if err != nil {
return err
return -1, fmt.Errorf("Error reversing migration %d (%q): %w", index, migration.Name, err)
}
_, err = tx.Exec(`DELETE FROM migration WHERE "index"=$1`, index)
if err != nil {
return -1, fmt.Errorf("Error deleting migration row %d (%q): %w", index, migration.Name, err)
}

committed = true
err = tx.Commit()
if err != nil {
return fmt.Errorf("Error committing migrations: %w", err)
return -1, fmt.Errorf("Error committing migrations: %w", err)
}
return index, nil
}

// Rollback runs the Reverse migrations for all the input migrations with index >= rollBackThroughIndex.
// The input migrations must include all migrations, not just the ones to roll back.
func Rollback(db *sqlx.DB, migrations []NamedMigration, rollBackThroughIndex int) error {
err := ensureMigrationsTableExists(db)
if err != nil {
return err
}
for {
rolledBackIndex, err := rollbackOne(db, migrations, rollBackThroughIndex)
if err != nil {
return err
}
if rolledBackIndex == rollBackThroughIndex {
return nil
}
}
}

// Migrate does the following:
// 1. Verifies that the `migration` table exists, and creates it if it does not.
// 2. Verifies that the existing migrations recorded in the database match (by name and order) the migrations given as the argument.
// 3. Performs any migrations that are not yet recorded in the database.
func Migrate(db *sqlx.DB, migrations []NamedMigration) error {
err := ensureMigrationsTableExists(db)
if err != nil {
return err
}
for {
migrated, err := migrateOne(db, migrations)
if err != nil {
return err
}
if !migrated {
return nil
}
}
committed = true
return nil
}
100 changes: 59 additions & 41 deletions migrations/verify.go
Original file line number Diff line number Diff line change
Expand Up @@ -65,15 +65,61 @@ func verifyNoTables(db *sqlx.DB) error {
return errors.New("Existing tables found. You must run SchemaTest on an empty database.")
}

func migrateAndRollback(emptyDBConfig *PostgresConfig, db *sqlx.DB, allMigrations []NamedMigration, migrateToIndex, rollbackThroughIndex int, repeatForward bool) error {
beforeMigrate, err := dump(emptyDBConfig)
if err != nil {
return fmt.Errorf("Error calling pg_dump: %s", err)
}
err = Migrate(db, allMigrations[:migrateToIndex+1])
if err != nil {
return fmt.Errorf("Migrate to %q failed: %s", allMigrations[migrateToIndex].Name, err)
}
afterMigrate, err := dump(emptyDBConfig)
if err != nil {
return fmt.Errorf("Error calling pg_dump: %s", err)
}
err = Rollback(db, allMigrations, rollbackThroughIndex)
if err != nil {
return fmt.Errorf("Rollback through %q failed: %s", allMigrations[rollbackThroughIndex].Name, err)
}
afterRollback, err := dump(emptyDBConfig)
if err != nil {
return fmt.Errorf("Error calling pg_dump: %s", err)
}
if string(beforeMigrate) != string(afterRollback) {
fmt.Printf("%s\n", cmp.Diff(string(beforeMigrate), string(afterRollback)))
return fmt.Errorf("Dump after rollback through %q did not match the dump before the migration", allMigrations[rollbackThroughIndex].Name)
}
if repeatForward {
err = Migrate(db, allMigrations[:migrateToIndex+1])
if err != nil {
return fmt.Errorf("Migration to %q failed: %s", allMigrations[migrateToIndex].Name, err)
}
afterMigrateAgain, err := dump(emptyDBConfig)
if err != nil {
return fmt.Errorf("Error calling pg_dump: %s", err)
}
if string(afterMigrate) != string(afterMigrateAgain) {
fmt.Printf("%s\n", cmp.Diff(string(afterMigrate), string(afterMigrateAgain)))
return fmt.Errorf("Dump after re-migration of %q did not match dump after first migration", allMigrations[migrateToIndex].Name)
}
}
return err
}

// Schema test expects a new *empty* postgres database.
// It will, for each migration:
// 1. Apply the migration
// 2. Reverse the migration
// 3. Apply the migration again
// It will:
// 1. Apply all migrations
// 2. Reverse all migrations
// 3. For each migration:
// a. Apply the migration
// b. Reverse the migration
// c. Apply the migration again
//
// Before and after each step it will use pg_dump to dump the database schema.
// It will verify that:
// A. The schema is the same after step 2 as before step 1.
// B. The schema is the same after step 3 as after step 1.
// A. The schema is the same after reversing as before applying.
// B. (If re-applying) The schema is the same after applying as after re-applying.
//
// You must have `pg_dump` in your `PATH` to run this.
func SchemaTest(emptyDBConfig *PostgresConfig, allMigrations []NamedMigration) error {
Expand Down Expand Up @@ -103,42 +149,14 @@ func SchemaTest(emptyDBConfig *PostgresConfig, allMigrations []NamedMigration) e
if err != nil {
return fmt.Errorf("Setting up migrations table failed: %s", err)
}
for idx, migration := range allMigrations {
beforeMigrate, err := dump(emptyDBConfig)
if err != nil {
return fmt.Errorf("Error calling pg_dump: %s", err)
}
err = Migrate(db, allMigrations[:idx+1])
if err != nil {
return fmt.Errorf("Migration %q failed: %s", migration.Name, err)
}
afterMigrate, err := dump(emptyDBConfig)
if err != nil {
return fmt.Errorf("Error calling pg_dump: %s", err)
}
err = Rollback(db, allMigrations, idx)
if err != nil {
return fmt.Errorf("Rollback to %q failed: %s", migration.Name, err)
}
afterRollback, err := dump(emptyDBConfig)
if err != nil {
return fmt.Errorf("Error calling pg_dump: %s", err)
}
if string(beforeMigrate) != string(afterRollback) {
fmt.Printf("%s\n", cmp.Diff(string(beforeMigrate), string(afterRollback)))
return fmt.Errorf("Dump after rollback of %q did not match the dump before the migration", migration.Name)
}
err = Migrate(db, allMigrations[:idx+1])
if err != nil {
return fmt.Errorf("Migration %q failed: %s", migration.Name, err)
}
afterMigrateAgain, err := dump(emptyDBConfig)
err = migrateAndRollback(emptyDBConfig, db, allMigrations, len(allMigrations)-1, 0, false)
if err != nil {
return err
}
for idx := range allMigrations {
err := migrateAndRollback(emptyDBConfig, db, allMigrations, idx, idx, true)
if err != nil {
return fmt.Errorf("Error calling pg_dump: %s", err)
}
if string(afterMigrate) != string(afterMigrateAgain) {
fmt.Printf("%s\n", cmp.Diff(string(afterMigrate), string(afterMigrateAgain)))
return fmt.Errorf("Dump after re-migration of %q did not match dump after first migration", migration.Name)
return err
}
}
return nil
Expand Down