Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refactor: distinguish between Unique and UniqueIndex #6386

Merged
merged 20 commits into from
Feb 4, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
118 changes: 63 additions & 55 deletions migrator/migrator.go
Original file line number Diff line number Diff line change
Expand Up @@ -110,15 +110,20 @@ func (m Migrator) FullDataTypeOf(field *schema.Field) (expr clause.Expr) {
return
}

func (m Migrator) GetQueryAndExecTx() (queryTx, execTx *gorm.DB) {
a631807682 marked this conversation as resolved.
Show resolved Hide resolved
queryTx = m.DB.Session(&gorm.Session{})
execTx = queryTx
if m.DB.DryRun {
queryTx.DryRun = false
execTx = m.DB.Session(&gorm.Session{Logger: &printSQLLogger{Interface: m.DB.Logger}})
}
return queryTx, execTx
}

// AutoMigrate auto migrate values
func (m Migrator) AutoMigrate(values ...interface{}) error {
for _, value := range m.ReorderModels(values, true) {
queryTx := m.DB.Session(&gorm.Session{})
execTx := queryTx
if m.DB.DryRun {
queryTx.DryRun = false
execTx = m.DB.Session(&gorm.Session{Logger: &printSQLLogger{Interface: m.DB.Logger}})
}
queryTx, execTx := m.GetQueryAndExecTx()
if !queryTx.Migrator().HasTable(value) {
if err := execTx.Migrator().CreateTable(value); err != nil {
return err
Expand Down Expand Up @@ -268,14 +273,19 @@ func (m Migrator) CreateTable(values ...interface{}) error {
}
if constraint := rel.ParseConstraint(); constraint != nil {
if constraint.Schema == stmt.Schema {
sql, vars := buildConstraint(constraint)
sql, vars := constraint.Build()
createTableSQL += sql + ","
values = append(values, vars...)
}
}
}
}

for _, uni := range stmt.Schema.ParseUniqueConstraints() {
createTableSQL += "CONSTRAINT ? UNIQUE (?),"
values = append(values, clause.Column{Name: uni.Name}, clause.Expr{SQL: stmt.Quote(uni.Field.DBName)})
}

for _, chk := range stmt.Schema.ParseCheckConstraints() {
createTableSQL += "CONSTRAINT ? CHECK (?),"
values = append(values, clause.Column{Name: chk.Name}, clause.Expr{SQL: chk.Constraint})
Expand Down Expand Up @@ -439,6 +449,10 @@ func (m Migrator) RenameColumn(value interface{}, oldName, newName string) error

// MigrateColumn migrate column
func (m Migrator) MigrateColumn(value interface{}, field *schema.Field, columnType gorm.ColumnType) error {
if field.IgnoreMigration {
return nil
}

// found, smart migrate
fullDataType := strings.TrimSpace(strings.ToLower(m.DB.Migrator().FullDataTypeOf(field).SQL))
realDataType := strings.ToLower(columnType.DatabaseTypeName())
Expand Down Expand Up @@ -499,7 +513,7 @@ func (m Migrator) MigrateColumn(value interface{}, field *schema.Field, columnTy
}

// check unique
if unique, ok := columnType.Unique(); ok && unique != field.Unique {
if unique, ok := columnType.Unique(); ok && unique != (field.Unique || field.UniqueIndex != "") {
// not primary key
if !field.PrimaryKey {
alterColumn = true
Expand Down Expand Up @@ -630,37 +644,36 @@ func (m Migrator) DropView(name string) error {
return m.DB.Exec("DROP VIEW IF EXISTS ?", clause.Table{Name: name}).Error
}

func buildConstraint(constraint *schema.Constraint) (sql string, results []interface{}) {
sql = "CONSTRAINT ? FOREIGN KEY ? REFERENCES ??"
if constraint.OnDelete != "" {
sql += " ON DELETE " + constraint.OnDelete
}

if constraint.OnUpdate != "" {
sql += " ON UPDATE " + constraint.OnUpdate
}

var foreignKeys, references []interface{}
for _, field := range constraint.ForeignKeys {
foreignKeys = append(foreignKeys, clause.Column{Name: field.DBName})
}

for _, field := range constraint.References {
references = append(references, clause.Column{Name: field.DBName})
}
results = append(results, clause.Table{Name: constraint.Name}, foreignKeys, clause.Table{Name: constraint.ReferenceSchema.Table}, references)
return
}

// GuessConstraintAndTable guess statement's constraint and it's table based on name
func (m Migrator) GuessConstraintAndTable(stmt *gorm.Statement, name string) (_ *schema.Constraint, _ *schema.Check, table string) {
//
// Deprecated: use GuessConstraintInterfaceAndTable instead.
func (m Migrator) GuessConstraintAndTable(stmt *gorm.Statement, name string) (*schema.Constraint, *schema.CheckConstraint, string) {
a631807682 marked this conversation as resolved.
Show resolved Hide resolved
constraint, table := m.GuessConstraintInterfaceAndTable(stmt, name)
switch c := constraint.(type) {
case *schema.Constraint:
return c, nil, table
case *schema.CheckConstraint:
return nil, c, table
default:
return nil, nil, table
}
}

// GuessConstraintInterfaceAndTable guess statement's constraint and it's table based on name
// nolint:cyclop
func (m Migrator) GuessConstraintInterfaceAndTable(stmt *gorm.Statement, name string) (_ schema.ConstraintInterface, table string) {
if stmt.Schema == nil {
return nil, nil, stmt.Table
return nil, stmt.Table
}

checkConstraints := stmt.Schema.ParseCheckConstraints()
if chk, ok := checkConstraints[name]; ok {
return nil, &chk, stmt.Table
return &chk, stmt.Table
}

uniqueConstraints := stmt.Schema.ParseUniqueConstraints()
if uni, ok := uniqueConstraints[name]; ok {
return &uni, stmt.Table
}

getTable := func(rel *schema.Relationship) string {
Expand All @@ -675,60 +688,57 @@ func (m Migrator) GuessConstraintAndTable(stmt *gorm.Statement, name string) (_

for _, rel := range stmt.Schema.Relationships.Relations {
if constraint := rel.ParseConstraint(); constraint != nil && constraint.Name == name {
return constraint, nil, getTable(rel)
return constraint, getTable(rel)
}
}

if field := stmt.Schema.LookUpField(name); field != nil {
for k := range checkConstraints {
if checkConstraints[k].Field == field {
v := checkConstraints[k]
return nil, &v, stmt.Table
return &v, stmt.Table
}
}

for k := range uniqueConstraints {
if uniqueConstraints[k].Field == field {
v := uniqueConstraints[k]
return &v, stmt.Table
}
}

for _, rel := range stmt.Schema.Relationships.Relations {
if constraint := rel.ParseConstraint(); constraint != nil && rel.Field == field {
return constraint, nil, getTable(rel)
return constraint, getTable(rel)
}
}
}

return nil, nil, stmt.Schema.Table
return nil, stmt.Schema.Table
}

// CreateConstraint create constraint
func (m Migrator) CreateConstraint(value interface{}, name string) error {
return m.RunWithValue(value, func(stmt *gorm.Statement) error {
constraint, chk, table := m.GuessConstraintAndTable(stmt, name)
if chk != nil {
return m.DB.Exec(
"ALTER TABLE ? ADD CONSTRAINT ? CHECK (?)",
m.CurrentTable(stmt), clause.Column{Name: chk.Name}, clause.Expr{SQL: chk.Constraint},
).Error
}

constraint, table := m.GuessConstraintInterfaceAndTable(stmt, name)
if constraint != nil {
vars := []interface{}{clause.Table{Name: table}}
if stmt.TableExpr != nil {
vars[0] = stmt.TableExpr
}
sql, values := buildConstraint(constraint)
sql, values := constraint.Build()
return m.DB.Exec("ALTER TABLE ? ADD "+sql, append(vars, values...)...).Error
}

return nil
})
}

// DropConstraint drop constraint
func (m Migrator) DropConstraint(value interface{}, name string) error {
return m.RunWithValue(value, func(stmt *gorm.Statement) error {
constraint, chk, table := m.GuessConstraintAndTable(stmt, name)
constraint, table := m.GuessConstraintInterfaceAndTable(stmt, name)
if constraint != nil {
name = constraint.Name
} else if chk != nil {
name = chk.Name
name = constraint.GetName()
}
return m.DB.Exec("ALTER TABLE ? DROP CONSTRAINT ?", clause.Table{Name: table}, clause.Column{Name: name}).Error
})
Expand All @@ -739,11 +749,9 @@ func (m Migrator) HasConstraint(value interface{}, name string) bool {
var count int64
m.RunWithValue(value, func(stmt *gorm.Statement) error {
currentDatabase := m.DB.Migrator().CurrentDatabase()
constraint, chk, table := m.GuessConstraintAndTable(stmt, name)
constraint, table := m.GuessConstraintInterfaceAndTable(stmt, name)
if constraint != nil {
name = constraint.Name
} else if chk != nil {
name = chk.Name
name = constraint.GetName()
}

return m.DB.Raw(
Expand Down
35 changes: 0 additions & 35 deletions schema/check.go

This file was deleted.

66 changes: 66 additions & 0 deletions schema/constraint.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
package schema

import (
"regexp"
"strings"

"gorm.io/gorm/clause"
)

// reg match english letters and midline
var regEnLetterAndMidline = regexp.MustCompile("^[A-Za-z-_]+$")

type CheckConstraint struct {
Name string
Constraint string // length(phone) >= 10
*Field
}

func (chk *CheckConstraint) GetName() string { return chk.Name }

func (chk *CheckConstraint) Build() (sql string, vars []interface{}) {
return "CONSTRAINT ? CHECK (?)", []interface{}{clause.Column{Name: chk.Name}, clause.Expr{SQL: chk.Constraint}}
}

// ParseCheckConstraints parse schema check constraints
func (schema *Schema) ParseCheckConstraints() map[string]CheckConstraint {
checks := map[string]CheckConstraint{}
for _, field := range schema.FieldsByDBName {
if chk := field.TagSettings["CHECK"]; chk != "" {
names := strings.Split(chk, ",")
if len(names) > 1 && regEnLetterAndMidline.MatchString(names[0]) {
checks[names[0]] = CheckConstraint{Name: names[0], Constraint: strings.Join(names[1:], ","), Field: field}
} else {
if names[0] == "" {
chk = strings.Join(names[1:], ",")
}
name := schema.namer.CheckerName(schema.Table, field.DBName)
checks[name] = CheckConstraint{Name: name, Constraint: chk, Field: field}
}
}
}
return checks
}

type UniqueConstraint struct {
Name string
Field *Field
}

func (uni *UniqueConstraint) GetName() string { return uni.Name }

func (uni *UniqueConstraint) Build() (sql string, vars []interface{}) {
return "CONSTRAINT ? UNIQUE (?)", []interface{}{clause.Column{Name: uni.Name}, clause.Column{Name: uni.Field.DBName}}
}

// ParseUniqueConstraints parse schema unique constraints
func (schema *Schema) ParseUniqueConstraints() map[string]UniqueConstraint {
uniques := make(map[string]UniqueConstraint)
for _, field := range schema.Fields {
if field.Unique {
name := schema.namer.UniqueName(schema.Table, field.DBName)
uniques[name] = UniqueConstraint{Name: name, Field: field}
}
}
return uniques
}
31 changes: 30 additions & 1 deletion schema/check_test.go → schema/constraint_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import (
"testing"

"gorm.io/gorm/schema"
"gorm.io/gorm/utils/tests"
)

type UserCheck struct {
Expand All @@ -20,7 +21,7 @@ func TestParseCheck(t *testing.T) {
t.Fatalf("failed to parse user check, got error %v", err)
}

results := map[string]schema.Check{
results := map[string]schema.CheckConstraint{
"name_checker": {
Name: "name_checker",
Constraint: "name <> 'jinzhu'",
Expand Down Expand Up @@ -53,3 +54,31 @@ func TestParseCheck(t *testing.T) {
}
}
}

func TestParseUniqueConstraints(t *testing.T) {
type UserUnique struct {
Name1 string `gorm:"unique"`
Name2 string `gorm:"uniqueIndex"`
}

user, err := schema.Parse(&UserUnique{}, &sync.Map{}, schema.NamingStrategy{})
if err != nil {
t.Fatalf("failed to parse user unique, got error %v", err)
}
constraints := user.ParseUniqueConstraints()

results := map[string]schema.UniqueConstraint{
"uni_user_uniques_name1": {
Name: "uni_user_uniques_name1",
Field: &schema.Field{Name: "Name1", Unique: true},
},
}
for k, result := range results {
v, ok := constraints[k]
if !ok {
t.Errorf("Failed to found unique constraint %v from parsed constraints %+v", k, constraints)
}
tests.AssertObjEqual(t, result, v, "Name")
tests.AssertObjEqual(t, result.Field, v.Field, "Name", "Unique", "UniqueIndex")
}
}
6 changes: 6 additions & 0 deletions schema/field.go
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,12 @@ type Field struct {
Set func(context.Context, reflect.Value, interface{}) error
Serializer SerializerInterface
NewValuePool FieldNewValuePool

// In some db (e.g. MySQL), Unique and UniqueIndex are indistinguishable.
// When a column has a (not Mul) UniqueIndex, Migrator always reports its gorm.ColumnType is Unique.
// It causes field unnecessarily migration.
// Therefore, we need to record the UniqueIndex on this column (exclude Mul UniqueIndex) for MigrateColumnUnique.
UniqueIndex string
}

func (field *Field) BindName() string {
Expand Down
Loading
Loading