diff --git a/migrate/helper_test.go b/migrate/helper_test.go deleted file mode 100644 index aa7095d9..00000000 --- a/migrate/helper_test.go +++ /dev/null @@ -1,44 +0,0 @@ -package migrate_test - -import ( - "github.com/JackC/pgx" - "github.com/JackC/pgx/migrate" -) - -type test interface { - Fatalf(format string, args ...interface{}) -} - -func mustConnect(t test, connectionParameters *pgx.ConnectionParameters) (conn *pgx.Connection) { - var err error - conn, err = pgx.Connect(*connectionParameters) - if err != nil { - t.Fatalf("Unable to establish connection: %v", err) - } - return -} - -func mustCreateMigrator(t test, conn *pgx.Connection) (m *migrate.Migrator) { - var err error - m, err = migrate.NewMigrator(conn, versionTable) - if err != nil { - t.Fatalf("Unable to create migrator: %v", err) - } - return -} - -func mustExecute(t test, conn *pgx.Connection, sql string, arguments ...interface{}) (commandTag string) { - var err error - if commandTag, err = conn.Execute(sql, arguments...); err != nil { - t.Fatalf("Execute unexpectedly failed with %v: %v", sql, err) - } - return -} - -func mustSelectValue(t test, conn *pgx.Connection, sql string, arguments ...interface{}) (value interface{}) { - var err error - if value, err = conn.SelectValue(sql, arguments...); err != nil { - t.Fatalf("SelectValue unexpectedly failed with %v: %v", sql, err) - } - return -} diff --git a/migrate/migrate.go b/migrate/migrate.go index 1c3020c9..81f4858e 100644 --- a/migrate/migrate.go +++ b/migrate/migrate.go @@ -14,14 +14,15 @@ func (e BadVersionError) Error() string { type Migration struct { Sequence int32 Name string - SQL string + UpSQL string + DownSQL string } type Migrator struct { conn *pgx.Connection versionTable string Migrations []*Migration - OnStart func(*Migration) `called when Migrate starts a migration` + OnStart func(*Migration, string) // OnStart is called when a migration is run with the migration and direction } func NewMigrator(conn *pgx.Connection, versionTable string) (m *Migrator, err error) { @@ -31,50 +32,79 @@ func NewMigrator(conn *pgx.Connection, versionTable string) (m *Migrator, err er return } -func (m *Migrator) AppendMigration(name, sql string) { - m.Migrations = append(m.Migrations, &Migration{Sequence: int32(len(m.Migrations) + 1), Name: name, SQL: sql}) +func (m *Migrator) AppendMigration(name, upSQL, downSQL string) { + m.Migrations = append(m.Migrations, &Migration{Sequence: int32(len(m.Migrations)) + 1, Name: name, UpSQL: upSQL, DownSQL: downSQL}) return } // Migrate runs pending migrations // It calls m.OnStart when it begins a migration func (m *Migrator) Migrate() error { - var done bool + return m.MigrateTo(int32(len(m.Migrations))) +} + +// MigrateTo migrates to targetVersion +func (m *Migrator) MigrateTo(targetVersion int32) (err error) { + // Lock to ensure multiple migrations cannot occur simultaneously + lockNum := int64(9628173550095224) // arbitrary random number + if _, lockErr := m.conn.Execute("select pg_advisory_lock($1)", lockNum); lockErr != nil { + return lockErr + } + defer func() { + _, unlockErr := m.conn.Execute("select pg_advisory_unlock($1)", lockNum) + if err == nil && unlockErr != nil { + err = unlockErr + } + }() + + currentVersion, err := m.GetCurrentVersion() + if err != nil { + return err + } + + if targetVersion < 0 || int32(len(m.Migrations)) < targetVersion { + errMsg := fmt.Sprintf("%s version %d is outside the valid versions of 0 to %d", m.versionTable, targetVersion, len(m.Migrations)) + return BadVersionError(errMsg) + } + + var direction int32 + if currentVersion < targetVersion { + direction = 1 + } else { + direction = -1 + } + + for currentVersion != targetVersion { + var current *Migration + var sql, directionName string + var sequence int32 + if direction == 1 { + current = m.Migrations[currentVersion] + sequence = current.Sequence + sql = current.UpSQL + directionName = "up" + } else { + current = m.Migrations[currentVersion-1] + sequence = current.Sequence - 1 + sql = current.DownSQL + directionName = "down" + } - for !done { var innerErr error - - var txErr error - _, txErr = m.conn.Transaction(func() bool { - // Lock version table for duration of transaction to ensure multiple migrations cannot occur simultaneously - if _, innerErr = m.conn.Execute("lock table " + m.versionTable); innerErr != nil { - return false - } - - // Get pending migrations - var pending []*Migration - if pending, innerErr = m.PendingMigrations(); innerErr != nil { - return false - } - - // If no migrations are pending set the done flag and return - if len(pending) == 0 { - done = true - return true - } + _, txErr := m.conn.Transaction(func() bool { // Fire on start callback if m.OnStart != nil { - m.OnStart(pending[0]) + m.OnStart(current, directionName) } - // Execute the first pending migration - if _, innerErr = m.conn.Execute(pending[0].SQL); innerErr != nil { + // Execute the migration + if _, innerErr = m.conn.Execute(sql); innerErr != nil { return false } // Add one to the version - if _, innerErr = m.conn.Execute("update " + m.versionTable + " set version=version+1"); innerErr != nil { + if _, innerErr = m.conn.Execute("update "+m.versionTable+" set version=$1", sequence); innerErr != nil { return false } @@ -88,28 +118,13 @@ func (m *Migrator) Migrate() error { if innerErr != nil { return innerErr } + + currentVersion = currentVersion + direction } return nil } -func (m *Migrator) PendingMigrations() ([]*Migration, error) { - if len(m.Migrations) == 0 { - return m.Migrations, nil - } - - if current, err := m.GetCurrentVersion(); err == nil { - current := int(current) - if current < 0 || len(m.Migrations) < current { - errMsg := fmt.Sprintf("%s version %d is outside the known migrations of 0 to %d", m.versionTable, current, len(m.Migrations)) - return nil, BadVersionError(errMsg) - } - return m.Migrations[current:len(m.Migrations)], nil - } else { - return nil, err - } -} - func (m *Migrator) GetCurrentVersion() (int32, error) { if v, err := m.conn.SelectValue("select version from " + m.versionTable); err == nil { return v.(int32), nil diff --git a/migrate/migrate_test.go b/migrate/migrate_test.go index dbe689d9..5616eb3e 100644 --- a/migrate/migrate_test.go +++ b/migrate/migrate_test.go @@ -4,148 +4,202 @@ import ( "fmt" "github.com/JackC/pgx" "github.com/JackC/pgx/migrate" + . "gopkg.in/check.v1" "testing" ) +type MigrateSuite struct { + conn *pgx.Connection +} + +func Test(t *testing.T) { TestingT(t) } + +var _ = Suite(&MigrateSuite{}) + var versionTable string = "schema_version" -func clearMigrate(t *testing.T, conn *pgx.Connection) { - tables := []string{versionTable, "t", "t1", "t2"} - for _, table := range tables { - mustExecute(t, conn, "drop table if exists "+table) - } +func (s *MigrateSuite) SetUpTest(c *C) { + var err error + s.conn, err = pgx.Connect(*defaultConnectionParameters) + c.Assert(err, IsNil) + + s.cleanupSampleMigrator(c) } -func TestNewMigrator(t *testing.T) { - conn := mustConnect(t, defaultConnectionParameters) - clearMigrate(t, conn) +func (s *MigrateSuite) SelectValue(c *C, sql string, arguments ...interface{}) interface{} { + value, err := s.conn.SelectValue(sql, arguments...) + c.Assert(err, IsNil) + return value +} - var m *migrate.Migrator - var err error - m, err = migrate.NewMigrator(conn, versionTable) - if err != nil { - t.Fatalf("Unable to create migrator: %v", err) - } +func (s *MigrateSuite) Execute(c *C, sql string, arguments ...interface{}) string { + commandTag, err := s.conn.Execute(sql, arguments...) + c.Assert(err, IsNil) + return commandTag +} - schemaVersionExists := mustSelectValue(t, - conn, +func (s *MigrateSuite) tableExists(c *C, tableName string) bool { + return s.SelectValue(c, "select exists(select 1 from information_schema.tables where table_catalog=$1 and table_name=$2)", defaultConnectionParameters.Database, - versionTable).(bool) + tableName).(bool) +} - if !schemaVersionExists { - t.Fatalf("NewMigrator did not create %v table", versionTable) - } +func (s *MigrateSuite) createEmptyMigrator(c *C) *migrate.Migrator { + var err error + m, err := migrate.NewMigrator(s.conn, versionTable) + c.Assert(err, IsNil) + return m +} - m, err = migrate.NewMigrator(conn, versionTable) - if err != nil { - t.Fatalf("NewMigrator failed when %v table already exists: %v", versionTable, err) - } +func (s *MigrateSuite) createSampleMigrator(c *C) *migrate.Migrator { + m := s.createEmptyMigrator(c) + m.AppendMigration("Create t1", "create table t1(id serial);", "drop table t1;") + m.AppendMigration("Create t2", "create table t2(id serial);", "drop table t2;") + m.AppendMigration("Create t3", "create table t3(id serial);", "drop table t3;") + return m +} - var initialVersion int32 - initialVersion, err = m.GetCurrentVersion() - if err != nil { - t.Fatalf("Failed to get current version: %v", err) - } - if initialVersion != 0 { - t.Fatalf("Expected initial version to be 0. but it was %v", initialVersion) +func (s *MigrateSuite) cleanupSampleMigrator(c *C) { + tables := []string{versionTable, "t1", "t2", "t3"} + for _, table := range tables { + s.Execute(c, "drop table if exists "+table) } } -func TestAppendMigration(t *testing.T) { - conn := mustConnect(t, defaultConnectionParameters) - clearMigrate(t, conn) - m := mustCreateMigrator(t, conn) +func (s *MigrateSuite) TestNewMigrator(c *C) { + var m *migrate.Migrator + var err error - name := "Update t" - sql := "update t set c=1" - m.AppendMigration(name, sql) + // Initial run + m, err = migrate.NewMigrator(s.conn, versionTable) + c.Assert(err, IsNil) - if len(m.Migrations) != 1 { - t.Fatal("Expected AppendMigration to add a migration but it didn't") - } - if m.Migrations[0].Name != name { - t.Fatalf("expected first migration Name to be %v, but it was %v", name, m.Migrations[0].Name) - } - if m.Migrations[0].SQL != sql { - t.Fatalf("expected first migration SQL to be %v, but it was %v", sql, m.Migrations[0].SQL) - } + // Creates version table + schemaVersionExists := s.tableExists(c, versionTable) + c.Assert(schemaVersionExists, Equals, true) + + // Succeeds when version table is already created + m, err = migrate.NewMigrator(s.conn, versionTable) + c.Assert(err, IsNil) + + initialVersion, err := m.GetCurrentVersion() + c.Assert(err, IsNil) + c.Assert(initialVersion, Equals, int32(0)) } -func TestPendingMigrations(t *testing.T) { - conn := mustConnect(t, defaultConnectionParameters) - clearMigrate(t, conn) - m := mustCreateMigrator(t, conn) +func (s *MigrateSuite) TestAppendMigration(c *C) { + m := s.createEmptyMigrator(c) - m.AppendMigration("update t", "update t set c=1") - m.AppendMigration("update z", "update z set c=1") + name := "Create t" + upSQL := "create t..." + downSQL := "drop t..." + m.AppendMigration(name, upSQL, downSQL) - mustExecute(t, conn, "update "+versionTable+" set version=1") - - pending, err := m.PendingMigrations() - if err != nil { - t.Fatalf("Unexpected error while getting pending migrations: %v", err) - } - if len(pending) != 1 { - t.Fatalf("Expected 1 pending migrations but there was %v", len(pending)) - } - if pending[0] != m.Migrations[1] { - t.Fatal("Did not include expected migration as pending") - } - - // Higher version than we know about - mustExecute(t, conn, "update "+versionTable+" set version=999") - _, err = m.PendingMigrations() - if _, ok := err.(migrate.BadVersionError); !ok { - t.Fatalf("Expected BadVersionError but received: %#v", err) - } - - // Lower version than is possible - mustExecute(t, conn, "update "+versionTable+" set version=-1") - _, err = m.PendingMigrations() - if _, ok := err.(migrate.BadVersionError); !ok { - t.Fatalf("Expected BadVersionError but received: %#v", err) - } + c.Assert(len(m.Migrations), Equals, 1) + c.Assert(m.Migrations[0].Name, Equals, name) + c.Assert(m.Migrations[0].UpSQL, Equals, upSQL) + c.Assert(m.Migrations[0].DownSQL, Equals, downSQL) } -func TestMigrate(t *testing.T) { - conn := mustConnect(t, defaultConnectionParameters) - clearMigrate(t, conn) - m := mustCreateMigrator(t, conn) +func (s *MigrateSuite) TestMigrate(c *C) { + m := s.createSampleMigrator(c) - m.AppendMigration("create t", "create table t(name text primary key)") + err := m.Migrate() + c.Assert(err, IsNil) + currentVersion := s.SelectValue(c, "select version from schema_version") + c.Assert(currentVersion, Equals, int32(3)) +} - if err := m.Migrate(); err != nil { - t.Fatalf("Unexpected error running Migrate: %v", err) +func (s *MigrateSuite) TestMigrateTo(c *C) { + m := s.createSampleMigrator(c) + + var onStartCallUpCount int + var onStartCallDownCount int + m.OnStart = func(_ *migrate.Migration, direction string) { + switch direction { + case "up": + onStartCallUpCount++ + case "down": + onStartCallDownCount++ + default: + c.Fatalf("Unexpected direction: %s", direction) + } } - if pending, err := m.PendingMigrations(); err != nil { - t.Fatalf("Unexpected error while getting pending migrations: %v", err) - } else if len(pending) != 0 { - t.Fatalf("Migrate did not do all migrations: %v pending", len(pending)) - } + // Migrate to -1 is error + err := m.MigrateTo(-1) + c.Assert(err, ErrorMatches, "schema_version version -1 is outside the valid versions of 0 to 3") - // Now test the OnStart callback and the Migrate when some are already done - var onStartCallCount int - m.OnStart = func(*migrate.Migration) { - onStartCallCount++ - } - m.AppendMigration("create t2", "create table t2(name text primary key)") + // Migrate past end is error + err = m.MigrateTo(int32(len(m.Migrations)) + 1) + c.Assert(err, ErrorMatches, "schema_version version 4 is outside the valid versions of 0 to 3") - if err := m.Migrate(); err != nil { - t.Fatalf("Unexpected error running Migrate: %v", err) - } + // Migrate from 0 up to 1 + err = m.MigrateTo(1) + c.Assert(err, IsNil) + currentVersion := s.SelectValue(c, "select version from schema_version") + c.Assert(currentVersion, Equals, int32(1)) + c.Assert(s.tableExists(c, "t1"), Equals, true) + c.Assert(s.tableExists(c, "t2"), Equals, false) + c.Assert(s.tableExists(c, "t3"), Equals, false) + c.Assert(onStartCallUpCount, Equals, 1) + c.Assert(onStartCallDownCount, Equals, 0) - if pending, err := m.PendingMigrations(); err != nil { - t.Fatalf("Unexpected error while getting pending migrations: %v", err) - } else if len(pending) != 0 { - t.Fatalf("Migrate did not do all migrations: %v pending", len(pending)) - } + // Migrate from 1 up to 3 + err = m.MigrateTo(3) + c.Assert(err, IsNil) + currentVersion = s.SelectValue(c, "select version from schema_version") + c.Assert(currentVersion, Equals, int32(3)) + c.Assert(s.tableExists(c, "t1"), Equals, true) + c.Assert(s.tableExists(c, "t2"), Equals, true) + c.Assert(s.tableExists(c, "t3"), Equals, true) + c.Assert(onStartCallUpCount, Equals, 3) + c.Assert(onStartCallDownCount, Equals, 0) - if onStartCallCount != 1 { - t.Fatalf("Expected OnStart to be called 1 time, but it was called %v times", onStartCallCount) - } + // Migrate from 3 to 3 is no-op + err = m.MigrateTo(3) + c.Assert(err, IsNil) + currentVersion = s.SelectValue(c, "select version from schema_version") + c.Assert(currentVersion, Equals, int32(3)) + c.Assert(s.tableExists(c, "t1"), Equals, true) + c.Assert(s.tableExists(c, "t2"), Equals, true) + c.Assert(s.tableExists(c, "t3"), Equals, true) + c.Assert(onStartCallUpCount, Equals, 3) + c.Assert(onStartCallDownCount, Equals, 0) + // Migrate from 3 down to 1 + err = m.MigrateTo(1) + c.Assert(err, IsNil) + currentVersion = s.SelectValue(c, "select version from schema_version") + c.Assert(currentVersion, Equals, int32(1)) + c.Assert(s.tableExists(c, "t1"), Equals, true) + c.Assert(s.tableExists(c, "t2"), Equals, false) + c.Assert(s.tableExists(c, "t3"), Equals, false) + c.Assert(onStartCallUpCount, Equals, 3) + c.Assert(onStartCallDownCount, Equals, 2) + + // Migrate from 1 down to 0 + err = m.MigrateTo(0) + c.Assert(err, IsNil) + currentVersion = s.SelectValue(c, "select version from schema_version") + c.Assert(currentVersion, Equals, int32(0)) + c.Assert(s.tableExists(c, "t1"), Equals, false) + c.Assert(s.tableExists(c, "t2"), Equals, false) + c.Assert(s.tableExists(c, "t3"), Equals, false) + c.Assert(onStartCallUpCount, Equals, 3) + c.Assert(onStartCallDownCount, Equals, 3) + + // Migrate back up to 3 + err = m.MigrateTo(3) + c.Assert(err, IsNil) + currentVersion = s.SelectValue(c, "select version from schema_version") + c.Assert(currentVersion, Equals, int32(3)) + c.Assert(s.tableExists(c, "t1"), Equals, true) + c.Assert(s.tableExists(c, "t2"), Equals, true) + c.Assert(s.tableExists(c, "t3"), Equals, true) + c.Assert(onStartCallUpCount, Equals, 6) + c.Assert(onStartCallDownCount, Equals, 3) } func Example_OnStartMigrationProgressLogging() { @@ -168,16 +222,16 @@ func Example_OnStartMigrationProgressLogging() { return } - m.OnStart = func(migration *migrate.Migration) { - fmt.Printf("Executing: %v", migration.Name) + m.OnStart = func(migration *migrate.Migration, direction string) { + fmt.Printf("Migrating %s: %s", direction, migration.Name) } - m.AppendMigration("create a table", "create temporary table foo(id serial primary key)") + m.AppendMigration("create a table", "create temporary table foo(id serial primary key)", "") if err = m.Migrate(); err != nil { fmt.Printf("Unexpected failure migrating: %v", err) return } // Output: - // Executing: create a table + // Migrating up: create a table }