Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

map insert support return increment id #6662

Merged
merged 7 commits into from
Nov 15, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
70 changes: 55 additions & 15 deletions callbacks/create.go
Original file line number Diff line number Diff line change
Expand Up @@ -103,13 +103,53 @@ func Create(config *Config) func(db *gorm.DB) {
}

db.RowsAffected, _ = result.RowsAffected()
if db.RowsAffected != 0 && db.Statement.Schema != nil &&
db.Statement.Schema.PrioritizedPrimaryField != nil &&
db.Statement.Schema.PrioritizedPrimaryField.HasDefaultValue {
insertID, err := result.LastInsertId()
insertOk := err == nil && insertID > 0
if !insertOk {
db.AddError(err)
if db.RowsAffected == 0 {
return
}

var (
pkField *schema.Field
pkFieldName = "@id"
)
if db.Statement.Schema != nil {
if db.Statement.Schema.PrioritizedPrimaryField == nil || !db.Statement.Schema.PrioritizedPrimaryField.HasDefaultValue {
return
}
pkField = db.Statement.Schema.PrioritizedPrimaryField
pkFieldName = db.Statement.Schema.PrioritizedPrimaryField.DBName
}

insertID, err := result.LastInsertId()
insertOk := err == nil && insertID > 0
if !insertOk {
db.AddError(err)
return
}

// append @id column with value for auto-increment primary key
// the @id value is correct, when: 1. without setting auto-increment primary key, 2. database AutoIncrementIncrement = 1
switch values := db.Statement.Dest.(type) {
case map[string]interface{}:
values[pkFieldName] = insertID
case *map[string]interface{}:
(*values)[pkFieldName] = insertID
case []map[string]interface{}, *[]map[string]interface{}:
mapValues, ok := values.([]map[string]interface{})
if !ok {
if v, ok := values.(*[]map[string]interface{}); ok {
if *v != nil {
mapValues = *v
}
}
}
for _, mapValue := range mapValues {
if mapValue != nil {
mapValue[pkFieldName] = insertID
}
insertID += schema.DefaultAutoIncrementIncrement
}
default:
if pkField == nil {
return
}

Expand All @@ -122,10 +162,10 @@ func Create(config *Config) func(db *gorm.DB) {
break
}

_, isZero := db.Statement.Schema.PrioritizedPrimaryField.ValueOf(db.Statement.Context, rv)
_, isZero := pkField.ValueOf(db.Statement.Context, rv)
if isZero {
db.AddError(db.Statement.Schema.PrioritizedPrimaryField.Set(db.Statement.Context, rv, insertID))
insertID -= db.Statement.Schema.PrioritizedPrimaryField.AutoIncrementIncrement
db.AddError(pkField.Set(db.Statement.Context, rv, insertID))
insertID -= pkField.AutoIncrementIncrement
}
}
} else {
Expand All @@ -135,16 +175,16 @@ func Create(config *Config) func(db *gorm.DB) {
break
}

if _, isZero := db.Statement.Schema.PrioritizedPrimaryField.ValueOf(db.Statement.Context, rv); isZero {
db.AddError(db.Statement.Schema.PrioritizedPrimaryField.Set(db.Statement.Context, rv, insertID))
insertID += db.Statement.Schema.PrioritizedPrimaryField.AutoIncrementIncrement
if _, isZero := pkField.ValueOf(db.Statement.Context, rv); isZero {
db.AddError(pkField.Set(db.Statement.Context, rv, insertID))
insertID += pkField.AutoIncrementIncrement
}
}
}
case reflect.Struct:
_, isZero := db.Statement.Schema.PrioritizedPrimaryField.ValueOf(db.Statement.Context, db.Statement.ReflectValue)
_, isZero := pkField.ValueOf(db.Statement.Context, db.Statement.ReflectValue)
if isZero {
db.AddError(db.Statement.Schema.PrioritizedPrimaryField.Set(db.Statement.Context, db.Statement.ReflectValue, insertID))
db.AddError(pkField.Set(db.Statement.Context, db.Statement.ReflectValue, insertID))
}
}
}
Expand Down
4 changes: 3 additions & 1 deletion schema/field.go
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,8 @@ const (
Bytes DataType = "bytes"
)

const DefaultAutoIncrementIncrement int64 = 1

// Field is the representation of model schema's field
type Field struct {
Name string
Expand Down Expand Up @@ -119,7 +121,7 @@ func (schema *Schema) ParseField(fieldStruct reflect.StructField) *Field {
NotNull: utils.CheckTruth(tagSetting["NOT NULL"], tagSetting["NOTNULL"]),
Unique: utils.CheckTruth(tagSetting["UNIQUE"]),
Comment: tagSetting["COMMENT"],
AutoIncrementIncrement: 1,
AutoIncrementIncrement: DefaultAutoIncrementIncrement,
}

for field.IndirectFieldType.Kind() == reflect.Ptr {
Expand Down
180 changes: 179 additions & 1 deletion tests/create_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package tests_test

import (
"errors"
"fmt"
"regexp"
"testing"
"time"
Expand Down Expand Up @@ -580,7 +581,7 @@ func TestCreateWithAutoIncrementCompositeKey(t *testing.T) {
}
}

func TestCreateOnConfilctWithDefalutNull(t *testing.T) {
func TestCreateOnConflictWithDefaultNull(t *testing.T) {
type OnConfilctUser struct {
ID string
Name string `gorm:"default:null"`
Expand Down Expand Up @@ -615,3 +616,180 @@ func TestCreateOnConfilctWithDefalutNull(t *testing.T) {
AssertEqual(t, u2.Email, "on-confilct-user-email-2")
AssertEqual(t, u2.Mobile, "133xxxx")
}

func TestCreateFromMapWithoutPK(t *testing.T) {
if !isMysql() {
t.Skipf("This test case skipped, because of only supportting for mysql")
}

// case 1: one record, create from map[string]interface{}
mapValue1 := map[string]interface{}{"name": "create_from_map_with_schema1", "age": 1}
if err := DB.Model(&User{}).Create(mapValue1).Error; err != nil {
t.Fatalf("failed to create data from map, got error: %v", err)
}

if _, ok := mapValue1["id"]; !ok {
t.Fatal("failed to create data from map with table, returning map has no primary key")
}

var result1 User
if err := DB.Where("name = ?", "create_from_map_with_schema1").First(&result1).Error; err != nil || result1.Age != 1 {
t.Fatalf("failed to create from map, got error %v", err)
}

var idVal int64
_, ok := mapValue1["id"].(uint)
if ok {
t.Skipf("This test case skipped, because the db supports returning")
}

idVal, ok = mapValue1["id"].(int64)
if !ok {
t.Fatal("ret result missing id")
}

if int64(result1.ID) != idVal {
t.Fatal("failed to create data from map with table, @id != id")
}

// case2: one record, create from *map[string]interface{}
mapValue2 := map[string]interface{}{"name": "create_from_map_with_schema2", "age": 1}
if err := DB.Model(&User{}).Create(&mapValue2).Error; err != nil {
t.Fatalf("failed to create data from map, got error: %v", err)
}

if _, ok := mapValue2["id"]; !ok {
t.Fatal("failed to create data from map with table, returning map has no primary key")
}

var result2 User
if err := DB.Where("name = ?", "create_from_map_with_schema2").First(&result2).Error; err != nil || result2.Age != 1 {
t.Fatalf("failed to create from map, got error %v", err)
}

_, ok = mapValue2["id"].(uint)
if ok {
t.Skipf("This test case skipped, because the db supports returning")
}

idVal, ok = mapValue2["id"].(int64)
if !ok {
t.Fatal("ret result missing id")
}

if int64(result2.ID) != idVal {
t.Fatal("failed to create data from map with table, @id != id")
}

// case 3: records
values := []map[string]interface{}{
{"name": "create_from_map_with_schema11", "age": 1}, {"name": "create_from_map_with_schema12", "age": 1},
}

beforeLen := len(values)
if err := DB.Model(&User{}).Create(&values).Error; err != nil {
t.Fatalf("failed to create data from map, got error: %v", err)
}

// mariadb with returning, values will be appended with id map
if len(values) == beforeLen*2 {
t.Skipf("This test case skipped, because the db supports returning")
}

for i := range values {
v, ok := values[i]["id"]
if !ok {
t.Fatal("failed to create data from map with table, returning map has no primary key")
}

var result User
if err := DB.Where("name = ?", fmt.Sprintf("create_from_map_with_schema1%d", i+1)).First(&result).Error; err != nil || result.Age != 1 {
t.Fatalf("failed to create from map, got error %v", err)
}
if int64(result.ID) != v.(int64) {
t.Fatal("failed to create data from map with table, @id != id")
}
}
}

func TestCreateFromMapWithTable(t *testing.T) {
if !isMysql() {
t.Skipf("This test case skipped, because of only supportting for mysql")
}
tableDB := DB.Table("`users`")

// case 1: create from map[string]interface{}
record := map[string]interface{}{"`name`": "create_from_map_with_table", "`age`": 18}
if err := tableDB.Create(record).Error; err != nil {
t.Fatalf("failed to create data from map with table, got error: %v", err)
}

if _, ok := record["@id"]; !ok {
t.Fatal("failed to create data from map with table, returning map has no key '@id'")
}

var res map[string]interface{}
if err := tableDB.Select([]string{"id", "name", "age"}).Where("name = ?", "create_from_map_with_table").Find(&res).Error; err != nil || res["age"] != int64(18) {
t.Fatalf("failed to create from map, got error %v", err)
}

if int64(res["id"].(uint64)) != record["@id"] {
t.Fatal("failed to create data from map with table, @id != id")
}

// case 2: create from *map[string]interface{}
record1 := map[string]interface{}{"name": "create_from_map_with_table_1", "age": 18}
tableDB2 := DB.Table("users")
if err := tableDB2.Create(&record1).Error; err != nil {
t.Fatalf("failed to create data from map, got error: %v", err)
}
if _, ok := record1["@id"]; !ok {
t.Fatal("failed to create data from map with table, returning map has no key '@id'")
}

var res1 map[string]interface{}
if err := tableDB2.Select([]string{"id", "name", "age"}).Where("name = ?", "create_from_map_with_table_1").Find(&res1).Error; err != nil || res1["age"] != int64(18) {
t.Fatalf("failed to create from map, got error %v", err)
}

if int64(res1["id"].(uint64)) != record1["@id"] {
t.Fatal("failed to create data from map with table, @id != id")
}

// case 3: create from []map[string]interface{}
records := []map[string]interface{}{
{"name": "create_from_map_with_table_2", "age": 19},
{"name": "create_from_map_with_table_3", "age": 20},
}

tableDB = DB.Table("users")
if err := tableDB.Create(&records).Error; err != nil {
t.Fatalf("failed to create data from slice of map, got error: %v", err)
}

if _, ok := records[0]["@id"]; !ok {
t.Fatal("failed to create data from map with table, returning map has no key '@id'")
}

if _, ok := records[1]["@id"]; !ok {
t.Fatal("failed to create data from map with table, returning map has no key '@id'")
}

var res2 map[string]interface{}
if err := tableDB.Select([]string{"id", "name", "age"}).Where("name = ?", "create_from_map_with_table_2").Find(&res2).Error; err != nil || res2["age"] != int64(19) {
t.Fatalf("failed to query data after create from slice of map, got error %v", err)
}

var res3 map[string]interface{}
if err := DB.Table("users").Select([]string{"id", "name", "age"}).Where("name = ?", "create_from_map_with_table_3").Find(&res3).Error; err != nil || res3["age"] != int64(20) {
t.Fatalf("failed to query data after create from slice of map, got error %v", err)
}

if int64(res2["id"].(uint64)) != records[0]["@id"] {
t.Fatal("failed to create data from map with table, @id != id")
}

if int64(res3["id"].(uint64)) != records[1]["@id"] {
t.Fatal("failed to create data from map with table, @id != id")
}
}
Loading