diff --git a/migrate/migrate.go b/migrate/migrate.go index 81f4858e..636d2f04 100644 --- a/migrate/migrate.go +++ b/migrate/migrate.go @@ -11,6 +11,14 @@ func (e BadVersionError) Error() string { return string(e) } +type IrreversibleMigrationError struct { + m *Migration +} + +func (e IrreversibleMigrationError) Error() string { + return fmt.Sprintf("Irreversible migration: %d - %s", e.m.Sequence, e.m.Name) +} + type Migration struct { Sequence int32 Name string @@ -88,6 +96,9 @@ func (m *Migrator) MigrateTo(targetVersion int32) (err error) { sequence = current.Sequence - 1 sql = current.DownSQL directionName = "down" + if current.DownSQL == "" { + return IrreversibleMigrationError{m: current} + } } var innerErr error diff --git a/migrate/migrate_test.go b/migrate/migrate_test.go index 5616eb3e..222e90aa 100644 --- a/migrate/migrate_test.go +++ b/migrate/migrate_test.go @@ -111,7 +111,7 @@ func (s *MigrateSuite) TestMigrate(c *C) { c.Assert(currentVersion, Equals, int32(3)) } -func (s *MigrateSuite) TestMigrateTo(c *C) { +func (s *MigrateSuite) TestMigrateToLifeCycle(c *C) { m := s.createSampleMigrator(c) var onStartCallUpCount int @@ -127,16 +127,8 @@ func (s *MigrateSuite) TestMigrateTo(c *C) { } } - // Migrate to -1 is error - err := m.MigrateTo(-1) - c.Assert(err, ErrorMatches, "schema_version version -1 is outside the valid versions of 0 to 3") - - // Migrate past end is error - err = m.MigrateTo(int32(len(m.Migrations)) + 1) - c.Assert(err, ErrorMatches, "schema_version version 4 is outside the valid versions of 0 to 3") - // Migrate from 0 up to 1 - err = m.MigrateTo(1) + err := m.MigrateTo(1) c.Assert(err, IsNil) currentVersion := s.SelectValue(c, "select version from schema_version") c.Assert(currentVersion, Equals, int32(1)) @@ -201,6 +193,28 @@ func (s *MigrateSuite) TestMigrateTo(c *C) { c.Assert(onStartCallUpCount, Equals, 6) c.Assert(onStartCallDownCount, Equals, 3) } +func (s *MigrateSuite) TestMigrateToBoundaries(c *C) { + m := s.createSampleMigrator(c) + + // Migrate to -1 is error + err := m.MigrateTo(-1) + c.Assert(err, ErrorMatches, "schema_version version -1 is outside the valid versions of 0 to 3") + + // Migrate past end is error + err = m.MigrateTo(int32(len(m.Migrations)) + 1) + c.Assert(err, ErrorMatches, "schema_version version 4 is outside the valid versions of 0 to 3") +} + +func (s *MigrateSuite) TestMigrateToIrreversible(c *C) { + m := s.createEmptyMigrator(c) + m.AppendMigration("Foo", "drop table if exists t3", "") + + err := m.MigrateTo(1) + c.Assert(err, IsNil) + + err = m.MigrateTo(0) + c.Assert(err, ErrorMatches, "Irreversible migration: 1 - Foo") +} func Example_OnStartMigrationProgressLogging() { conn, err := pgx.Connect(*defaultConnectionParameters)