diff --git a/migrator.go b/migrator.go index 3d2b032b0..64d803628 100644 --- a/migrator.go +++ b/migrator.go @@ -108,4 +108,8 @@ type Migrator interface { HasIndex(dst interface{}, name string) bool RenameIndex(dst interface{}, oldName, newName string) error GetIndexes(dst interface{}) ([]Index, error) + + // Locking + ObtainLock() error + ReleaseLock() error } diff --git a/migrator/migrator.go b/migrator/migrator.go index 64a5a4b52..9299429e5 100644 --- a/migrator/migrator.go +++ b/migrator/migrator.go @@ -110,87 +110,105 @@ func (m Migrator) FullDataTypeOf(field *schema.Field) (expr clause.Expr) { return } -// AutoMigrate auto migrate values -func (m Migrator) AutoMigrate(values ...interface{}) error { - for _, value := range m.ReorderModels(values, true) { - queryTx := m.DB.Session(&gorm.Session{}) - execTx := queryTx - if m.DB.DryRun { - queryTx.DryRun = false - execTx = m.DB.Session(&gorm.Session{Logger: &printSQLLogger{Interface: m.DB.Logger}}) +func (m Migrator) migrateTable(queryTx, execTx *gorm.DB, value interface{}) (err error) { + if err = execTx.Migrator().ObtainLock(); err != nil { + return + } + defer func() { + releaseErr := execTx.Migrator().ReleaseLock() + if err == nil { + err = releaseErr + } + }() + + if !queryTx.Migrator().HasTable(value) { + if err = execTx.Migrator().CreateTable(value); err != nil { + return err } - if !queryTx.Migrator().HasTable(value) { - if err := execTx.Migrator().CreateTable(value); err != nil { + } else { + if err = m.RunWithValue(value, func(stmt *gorm.Statement) error { + columnTypes, err := queryTx.Migrator().ColumnTypes(value) + if err != nil { return err } - } else { - if err := m.RunWithValue(value, func(stmt *gorm.Statement) error { - columnTypes, err := queryTx.Migrator().ColumnTypes(value) - if err != nil { - return err - } - var ( - parseIndexes = stmt.Schema.ParseIndexes() - parseCheckConstraints = stmt.Schema.ParseCheckConstraints() - ) - for _, dbName := range stmt.Schema.DBNames { - var foundColumn gorm.ColumnType - - for _, columnType := range columnTypes { - if columnType.Name() == dbName { - foundColumn = columnType - break - } - } + var ( + parseIndexes = stmt.Schema.ParseIndexes() + parseCheckConstraints = stmt.Schema.ParseCheckConstraints() + ) + for _, dbName := range stmt.Schema.DBNames { + var foundColumn gorm.ColumnType - if foundColumn == nil { - // not found, add column - if err = execTx.Migrator().AddColumn(value, dbName); err != nil { - return err - } - } else { - // found, smartly migrate - field := stmt.Schema.FieldsByDBName[dbName] - if err = execTx.Migrator().MigrateColumn(value, field, foundColumn); err != nil { - return err - } + for _, columnType := range columnTypes { + if columnType.Name() == dbName { + foundColumn = columnType + break } } - if !m.DB.DisableForeignKeyConstraintWhenMigrating && !m.DB.IgnoreRelationshipsWhenMigrating { - for _, rel := range stmt.Schema.Relationships.Relations { - if rel.Field.IgnoreMigration { - continue - } - if constraint := rel.ParseConstraint(); constraint != nil && - constraint.Schema == stmt.Schema && !queryTx.Migrator().HasConstraint(value, constraint.Name) { - if err := execTx.Migrator().CreateConstraint(value, constraint.Name); err != nil { - return err - } - } + if foundColumn == nil { + // not found, add column + if err = execTx.Migrator().AddColumn(value, dbName); err != nil { + return err + } + } else { + // found, smartly migrate + field := stmt.Schema.FieldsByDBName[dbName] + if err = execTx.Migrator().MigrateColumn(value, field, foundColumn); err != nil { + return err } } + } - for _, chk := range parseCheckConstraints { - if !queryTx.Migrator().HasConstraint(value, chk.Name) { - if err := execTx.Migrator().CreateConstraint(value, chk.Name); err != nil { + if !m.DB.DisableForeignKeyConstraintWhenMigrating && !m.DB.IgnoreRelationshipsWhenMigrating { + for _, rel := range stmt.Schema.Relationships.Relations { + if rel.Field.IgnoreMigration { + continue + } + if constraint := rel.ParseConstraint(); constraint != nil && + constraint.Schema == stmt.Schema && !queryTx.Migrator().HasConstraint(value, constraint.Name) { + if err := execTx.Migrator().CreateConstraint(value, constraint.Name); err != nil { return err } } } + } - for _, idx := range parseIndexes { - if !queryTx.Migrator().HasIndex(value, idx.Name) { - if err := execTx.Migrator().CreateIndex(value, idx.Name); err != nil { - return err - } + for _, chk := range parseCheckConstraints { + if !queryTx.Migrator().HasConstraint(value, chk.Name) { + if err := execTx.Migrator().CreateConstraint(value, chk.Name); err != nil { + return err } } + } - return nil - }); err != nil { - return err + for _, idx := range parseIndexes { + if !queryTx.Migrator().HasIndex(value, idx.Name) { + if err := execTx.Migrator().CreateIndex(value, idx.Name); err != nil { + return err + } + } } + + return nil + }); err != nil { + return err + } + } + return nil +} + +// AutoMigrate auto migrate values +func (m Migrator) AutoMigrate(values ...interface{}) error { + for _, value := range m.ReorderModels(values, true) { + queryTx := m.DB.Session(&gorm.Session{}) + execTx := queryTx + if m.DB.DryRun { + queryTx.DryRun = false + execTx = m.DB.Session(&gorm.Session{Logger: &printSQLLogger{Interface: m.DB.Logger}}) + } + + if err := m.migrateTable(queryTx, execTx, value); err != nil { + return err } } @@ -985,3 +1003,13 @@ func (m Migrator) GetTypeAliases(databaseTypeName string) []string { func (m Migrator) TableType(dst interface{}) (gorm.TableType, error) { return nil, errors.New("not support") } + +// ObtainLock obtains a global migration lock +func (m Migrator) ObtainLock() error { + return nil +} + +// ReleaseLock releases the global migration lock +func (m Migrator) ReleaseLock() error { + return nil +} diff --git a/tests/migrate_test.go b/tests/migrate_test.go index cfd3e0ace..b1aa27a12 100644 --- a/tests/migrate_test.go +++ b/tests/migrate_test.go @@ -1685,3 +1685,29 @@ func TestTableType(t *testing.T) { t.Fatalf("expected comment %s got %s", tblComment, comment) } } + +func TestMigrateRaceCondition(t *testing.T) { + type TestTable struct { + gorm.Model + } + + for a := 0; a < 2; a++ { + t.Run("drop and migrate", func(t *testing.T) { + t.Run("drop", func(t *testing.T) { + if err := DB.Migrator().DropTable(&TestTable{}); err != nil { + t.Fatalf("failed to drop table: %v", err) + } + }) + + for i := 0; i < 2; i++ { + t.Run("migrate", func(t *testing.T) { + t.Parallel() + + if err := DB.AutoMigrate(&TestTable{}); err != nil { + t.Fatalf("failed to migrate: %v", err) + } + }) + } + }) + } +}