diff --git a/migrate/connection_settings_test.go.example b/migrate/connection_settings_test.go.example deleted file mode 100644 index a9acac72..00000000 --- a/migrate/connection_settings_test.go.example +++ /dev/null @@ -1,7 +0,0 @@ -package migrate_test - -import ( - "github.com/JackC/pgx" -) - -var defaultConnectionParameters *pgx.ConnectionParameters = &pgx.ConnectionParameters{Host: "127.0.0.1", User: "pgx_md5", Password: "secret", Database: "pgx_test"} diff --git a/migrate/migrate.go b/migrate/migrate.go deleted file mode 100644 index eda9c4c4..00000000 --- a/migrate/migrate.go +++ /dev/null @@ -1,194 +0,0 @@ -package migrate - -import ( - "fmt" - "github.com/JackC/pgx" - "io/ioutil" - "path/filepath" - "strings" -) - -type BadVersionError string - -func (e BadVersionError) Error() string { - return string(e) -} - -type IrreversibleMigrationError struct { - m *Migration -} - -func (e IrreversibleMigrationError) Error() string { - return fmt.Sprintf("Irreversible migration: %d - %s", e.m.Sequence, e.m.Name) -} - -type NoMigrationsFoundError struct { - Path string -} - -func (e NoMigrationsFoundError) Error() string { - return fmt.Sprintf("No migrations found at %s", e.Path) -} - -type Migration struct { - Sequence int32 - Name string - UpSQL string - DownSQL string -} - -type Migrator struct { - conn *pgx.Connection - versionTable string - Migrations []*Migration - OnStart func(*Migration, string) // OnStart is called when a migration is run with the migration and direction -} - -func NewMigrator(conn *pgx.Connection, versionTable string) (m *Migrator, err error) { - m = &Migrator{conn: conn, versionTable: versionTable} - err = m.ensureSchemaVersionTableExists() - m.Migrations = make([]*Migration, 0) - return -} - -func (m *Migrator) LoadMigrations(path string) error { - paths, err := filepath.Glob(filepath.Join(path, "*.sql")) - if err != nil { - return err - } - if len(paths) == 0 { - return NoMigrationsFoundError{Path: path} - } - - for _, p := range paths { - body, err := ioutil.ReadFile(p) - if err != nil { - return err - } - - pieces := strings.SplitN(string(body), "---- create above / drop below ----", 2) - var upSQL, downSQL string - upSQL = strings.TrimSpace(pieces[0]) - if len(pieces) == 2 { - downSQL = strings.TrimSpace(pieces[1]) - } - m.AppendMigration(filepath.Base(p), upSQL, downSQL) - } - - return nil -} - -func (m *Migrator) AppendMigration(name, upSQL, downSQL string) { - m.Migrations = append(m.Migrations, &Migration{Sequence: int32(len(m.Migrations)) + 1, Name: name, UpSQL: upSQL, DownSQL: downSQL}) - return -} - -// Migrate runs pending migrations -// It calls m.OnStart when it begins a migration -func (m *Migrator) Migrate() error { - return m.MigrateTo(int32(len(m.Migrations))) -} - -// MigrateTo migrates to targetVersion -func (m *Migrator) MigrateTo(targetVersion int32) (err error) { - // Lock to ensure multiple migrations cannot occur simultaneously - lockNum := int64(9628173550095224) // arbitrary random number - if _, lockErr := m.conn.Execute("select pg_advisory_lock($1)", lockNum); lockErr != nil { - return lockErr - } - defer func() { - _, unlockErr := m.conn.Execute("select pg_advisory_unlock($1)", lockNum) - if err == nil && unlockErr != nil { - err = unlockErr - } - }() - - currentVersion, err := m.GetCurrentVersion() - if err != nil { - return err - } - - if targetVersion < 0 || int32(len(m.Migrations)) < targetVersion { - errMsg := fmt.Sprintf("%s version %d is outside the valid versions of 0 to %d", m.versionTable, targetVersion, len(m.Migrations)) - return BadVersionError(errMsg) - } - - var direction int32 - if currentVersion < targetVersion { - direction = 1 - } else { - direction = -1 - } - - for currentVersion != targetVersion { - var current *Migration - var sql, directionName string - var sequence int32 - if direction == 1 { - current = m.Migrations[currentVersion] - sequence = current.Sequence - sql = current.UpSQL - directionName = "up" - } else { - current = m.Migrations[currentVersion-1] - sequence = current.Sequence - 1 - sql = current.DownSQL - directionName = "down" - if current.DownSQL == "" { - return IrreversibleMigrationError{m: current} - } - } - - var innerErr error - _, txErr := m.conn.Transaction(func() bool { - - // Fire on start callback - if m.OnStart != nil { - m.OnStart(current, directionName) - } - - // Execute the migration - if _, innerErr = m.conn.Execute(sql); innerErr != nil { - return false - } - - // Add one to the version - if _, innerErr = m.conn.Execute("update "+m.versionTable+" set version=$1", sequence); innerErr != nil { - return false - } - - // A migration was completed successfully, return true to commit the transaction - return true - }) - - if txErr != nil { - return txErr - } - if innerErr != nil { - return innerErr - } - - currentVersion = currentVersion + direction - } - - return nil -} - -func (m *Migrator) GetCurrentVersion() (int32, error) { - if v, err := m.conn.SelectValue("select version from " + m.versionTable); err == nil { - return v.(int32), nil - } else { - return 0, err - } -} - -func (m *Migrator) ensureSchemaVersionTableExists() (err error) { - _, err = m.conn.Execute(` - create table if not exists schema_version(version int4 not null); - - insert into schema_version(version) - select 0 - where 0=(select count(*) from schema_version); - `) - return -} diff --git a/migrate/migrate_test.go b/migrate/migrate_test.go deleted file mode 100644 index 6113b485..00000000 --- a/migrate/migrate_test.go +++ /dev/null @@ -1,287 +0,0 @@ -package migrate_test - -import ( - "fmt" - "github.com/JackC/pgx" - "github.com/JackC/pgx/migrate" - . "gopkg.in/check.v1" - "testing" -) - -type MigrateSuite struct { - conn *pgx.Connection -} - -func Test(t *testing.T) { TestingT(t) } - -var _ = Suite(&MigrateSuite{}) - -var versionTable string = "schema_version" - -func (s *MigrateSuite) SetUpTest(c *C) { - var err error - s.conn, err = pgx.Connect(*defaultConnectionParameters) - c.Assert(err, IsNil) - - s.cleanupSampleMigrator(c) -} - -func (s *MigrateSuite) SelectValue(c *C, sql string, arguments ...interface{}) interface{} { - value, err := s.conn.SelectValue(sql, arguments...) - c.Assert(err, IsNil) - return value -} - -func (s *MigrateSuite) Execute(c *C, sql string, arguments ...interface{}) string { - commandTag, err := s.conn.Execute(sql, arguments...) - c.Assert(err, IsNil) - return commandTag -} - -func (s *MigrateSuite) tableExists(c *C, tableName string) bool { - return s.SelectValue(c, - "select exists(select 1 from information_schema.tables where table_catalog=$1 and table_name=$2)", - defaultConnectionParameters.Database, - tableName).(bool) -} - -func (s *MigrateSuite) createEmptyMigrator(c *C) *migrate.Migrator { - var err error - m, err := migrate.NewMigrator(s.conn, versionTable) - c.Assert(err, IsNil) - return m -} - -func (s *MigrateSuite) createSampleMigrator(c *C) *migrate.Migrator { - m := s.createEmptyMigrator(c) - m.AppendMigration("Create t1", "create table t1(id serial);", "drop table t1;") - m.AppendMigration("Create t2", "create table t2(id serial);", "drop table t2;") - m.AppendMigration("Create t3", "create table t3(id serial);", "drop table t3;") - return m -} - -func (s *MigrateSuite) cleanupSampleMigrator(c *C) { - tables := []string{versionTable, "t1", "t2", "t3"} - for _, table := range tables { - s.Execute(c, "drop table if exists "+table) - } -} - -func (s *MigrateSuite) TestNewMigrator(c *C) { - var m *migrate.Migrator - var err error - - // Initial run - m, err = migrate.NewMigrator(s.conn, versionTable) - c.Assert(err, IsNil) - - // Creates version table - schemaVersionExists := s.tableExists(c, versionTable) - c.Assert(schemaVersionExists, Equals, true) - - // Succeeds when version table is already created - m, err = migrate.NewMigrator(s.conn, versionTable) - c.Assert(err, IsNil) - - initialVersion, err := m.GetCurrentVersion() - c.Assert(err, IsNil) - c.Assert(initialVersion, Equals, int32(0)) -} - -func (s *MigrateSuite) TestAppendMigration(c *C) { - m := s.createEmptyMigrator(c) - - name := "Create t" - upSQL := "create t..." - downSQL := "drop t..." - m.AppendMigration(name, upSQL, downSQL) - - c.Assert(len(m.Migrations), Equals, 1) - c.Assert(m.Migrations[0].Name, Equals, name) - c.Assert(m.Migrations[0].UpSQL, Equals, upSQL) - c.Assert(m.Migrations[0].DownSQL, Equals, downSQL) -} - -func (s *MigrateSuite) TestLoadMigrationsMissingDirectory(c *C) { - m := s.createEmptyMigrator(c) - err := m.LoadMigrations("testdata/missing") - c.Assert(err, ErrorMatches, "No migrations found at testdata/missing") -} - -func (s *MigrateSuite) TestLoadMigrationsEmptyDirectory(c *C) { - m := s.createEmptyMigrator(c) - err := m.LoadMigrations("testdata/empty") - c.Assert(err, ErrorMatches, "No migrations found at testdata/empty") -} - -func (s *MigrateSuite) TestLoadMigrations(c *C) { - m := s.createEmptyMigrator(c) - err := m.LoadMigrations("testdata/sample") - c.Assert(err, IsNil) - c.Assert(m.Migrations, HasLen, 3) - - c.Check(m.Migrations[0].Name, Equals, "001_create_t1.sql") - c.Check(m.Migrations[0].UpSQL, Equals, `create table t1( - id serial primary key -);`) - c.Check(m.Migrations[0].DownSQL, Equals, "drop table t1;") - - c.Check(m.Migrations[1].Name, Equals, "002_create_t2.sql") - c.Check(m.Migrations[1].UpSQL, Equals, `create table t2( - id serial primary key -);`) - c.Check(m.Migrations[1].DownSQL, Equals, "drop table t2;") - - c.Check(m.Migrations[2].Name, Equals, "003_irreversible.sql") - c.Check(m.Migrations[2].UpSQL, Equals, "drop table t2;") - c.Check(m.Migrations[2].DownSQL, Equals, "") -} - -func (s *MigrateSuite) TestMigrate(c *C) { - m := s.createSampleMigrator(c) - - err := m.Migrate() - c.Assert(err, IsNil) - currentVersion := s.SelectValue(c, "select version from schema_version") - c.Assert(currentVersion, Equals, int32(3)) -} - -func (s *MigrateSuite) TestMigrateToLifeCycle(c *C) { - m := s.createSampleMigrator(c) - - var onStartCallUpCount int - var onStartCallDownCount int - m.OnStart = func(_ *migrate.Migration, direction string) { - switch direction { - case "up": - onStartCallUpCount++ - case "down": - onStartCallDownCount++ - default: - c.Fatalf("Unexpected direction: %s", direction) - } - } - - // Migrate from 0 up to 1 - err := m.MigrateTo(1) - c.Assert(err, IsNil) - currentVersion := s.SelectValue(c, "select version from schema_version") - c.Assert(currentVersion, Equals, int32(1)) - c.Assert(s.tableExists(c, "t1"), Equals, true) - c.Assert(s.tableExists(c, "t2"), Equals, false) - c.Assert(s.tableExists(c, "t3"), Equals, false) - c.Assert(onStartCallUpCount, Equals, 1) - c.Assert(onStartCallDownCount, Equals, 0) - - // Migrate from 1 up to 3 - err = m.MigrateTo(3) - c.Assert(err, IsNil) - currentVersion = s.SelectValue(c, "select version from schema_version") - c.Assert(currentVersion, Equals, int32(3)) - c.Assert(s.tableExists(c, "t1"), Equals, true) - c.Assert(s.tableExists(c, "t2"), Equals, true) - c.Assert(s.tableExists(c, "t3"), Equals, true) - c.Assert(onStartCallUpCount, Equals, 3) - c.Assert(onStartCallDownCount, Equals, 0) - - // Migrate from 3 to 3 is no-op - err = m.MigrateTo(3) - c.Assert(err, IsNil) - currentVersion = s.SelectValue(c, "select version from schema_version") - c.Assert(currentVersion, Equals, int32(3)) - c.Assert(s.tableExists(c, "t1"), Equals, true) - c.Assert(s.tableExists(c, "t2"), Equals, true) - c.Assert(s.tableExists(c, "t3"), Equals, true) - c.Assert(onStartCallUpCount, Equals, 3) - c.Assert(onStartCallDownCount, Equals, 0) - - // Migrate from 3 down to 1 - err = m.MigrateTo(1) - c.Assert(err, IsNil) - currentVersion = s.SelectValue(c, "select version from schema_version") - c.Assert(currentVersion, Equals, int32(1)) - c.Assert(s.tableExists(c, "t1"), Equals, true) - c.Assert(s.tableExists(c, "t2"), Equals, false) - c.Assert(s.tableExists(c, "t3"), Equals, false) - c.Assert(onStartCallUpCount, Equals, 3) - c.Assert(onStartCallDownCount, Equals, 2) - - // Migrate from 1 down to 0 - err = m.MigrateTo(0) - c.Assert(err, IsNil) - currentVersion = s.SelectValue(c, "select version from schema_version") - c.Assert(currentVersion, Equals, int32(0)) - c.Assert(s.tableExists(c, "t1"), Equals, false) - c.Assert(s.tableExists(c, "t2"), Equals, false) - c.Assert(s.tableExists(c, "t3"), Equals, false) - c.Assert(onStartCallUpCount, Equals, 3) - c.Assert(onStartCallDownCount, Equals, 3) - - // Migrate back up to 3 - err = m.MigrateTo(3) - c.Assert(err, IsNil) - currentVersion = s.SelectValue(c, "select version from schema_version") - c.Assert(currentVersion, Equals, int32(3)) - c.Assert(s.tableExists(c, "t1"), Equals, true) - c.Assert(s.tableExists(c, "t2"), Equals, true) - c.Assert(s.tableExists(c, "t3"), Equals, true) - c.Assert(onStartCallUpCount, Equals, 6) - c.Assert(onStartCallDownCount, Equals, 3) -} - -func (s *MigrateSuite) TestMigrateToBoundaries(c *C) { - m := s.createSampleMigrator(c) - - // Migrate to -1 is error - err := m.MigrateTo(-1) - c.Assert(err, ErrorMatches, "schema_version version -1 is outside the valid versions of 0 to 3") - - // Migrate past end is error - err = m.MigrateTo(int32(len(m.Migrations)) + 1) - c.Assert(err, ErrorMatches, "schema_version version 4 is outside the valid versions of 0 to 3") -} - -func (s *MigrateSuite) TestMigrateToIrreversible(c *C) { - m := s.createEmptyMigrator(c) - m.AppendMigration("Foo", "drop table if exists t3", "") - - err := m.MigrateTo(1) - c.Assert(err, IsNil) - - err = m.MigrateTo(0) - c.Assert(err, ErrorMatches, "Irreversible migration: 1 - Foo") -} - -func Example_OnStartMigrationProgressLogging() { - conn, err := pgx.Connect(*defaultConnectionParameters) - if err != nil { - fmt.Printf("Unable to establish connection: %v", err) - return - } - - // Clear any previous runs - if _, err = conn.Execute("drop table if exists schema_version"); err != nil { - fmt.Printf("Unable to drop schema_version table: %v", err) - return - } - - var m *migrate.Migrator - m, err = migrate.NewMigrator(conn, "schema_version") - if err != nil { - fmt.Printf("Unable to create migrator: %v", err) - return - } - - m.OnStart = func(migration *migrate.Migration, direction string) { - fmt.Printf("Migrating %s: %s", direction, migration.Name) - } - - m.AppendMigration("create a table", "create temporary table foo(id serial primary key)", "") - - if err = m.Migrate(); err != nil { - fmt.Printf("Unexpected failure migrating: %v", err) - return - } - // Output: - // Migrating up: create a table -} diff --git a/migrate/testdata/sample/001_create_t1.sql b/migrate/testdata/sample/001_create_t1.sql deleted file mode 100644 index 87a05c19..00000000 --- a/migrate/testdata/sample/001_create_t1.sql +++ /dev/null @@ -1,7 +0,0 @@ -create table t1( - id serial primary key -); - ----- create above / drop below ---- - -drop table t1; diff --git a/migrate/testdata/sample/002_create_t2.sql b/migrate/testdata/sample/002_create_t2.sql deleted file mode 100644 index 352e5149..00000000 --- a/migrate/testdata/sample/002_create_t2.sql +++ /dev/null @@ -1,7 +0,0 @@ -create table t2( - id serial primary key -); - ----- create above / drop below ---- - -drop table t2; diff --git a/migrate/testdata/sample/003_irreversible.sql b/migrate/testdata/sample/003_irreversible.sql deleted file mode 100644 index fcb15c77..00000000 --- a/migrate/testdata/sample/003_irreversible.sql +++ /dev/null @@ -1 +0,0 @@ -drop table t2;