Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Adds AfterError callback hook #6649

Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions callbacks/callbacks.go
Original file line number Diff line number Diff line change
Expand Up @@ -42,13 +42,15 @@
createCallback.Register("gorm:before_create", BeforeCreate)
createCallback.Register("gorm:save_before_associations", SaveBeforeAssociations(true))
createCallback.Register("gorm:create", Create(config))
createCallback.Register("gorm:after_error", AfterError)
createCallback.Register("gorm:save_after_associations", SaveAfterAssociations(true))
createCallback.Register("gorm:after_create", AfterCreate)
createCallback.Match(enableTransaction).Register("gorm:commit_or_rollback_transaction", CommitOrRollbackTransaction)
createCallback.Clauses = config.CreateClauses

queryCallback := db.Callback().Query()
queryCallback.Register("gorm:query", Query)
queryCallback.Register("gorm:after_error", AfterError)

Check failure on line 53 in callbacks/callbacks.go

View workflow job for this annotation

GitHub Actions / runner / golangci-lint

[golangci] reported by reviewdog 🐶 Error return value of `queryCallback.Register` is not checked (errcheck) Raw Output: callbacks/callbacks.go:53:24: Error return value of `queryCallback.Register` is not checked (errcheck) queryCallback.Register("gorm:after_error", AfterError) ^
queryCallback.Register("gorm:preload", Preload)
queryCallback.Register("gorm:after_query", AfterQuery)
queryCallback.Clauses = config.QueryClauses
Expand All @@ -58,6 +60,7 @@
deleteCallback.Register("gorm:before_delete", BeforeDelete)
deleteCallback.Register("gorm:delete_before_associations", DeleteBeforeAssociations)
deleteCallback.Register("gorm:delete", Delete(config))
deleteCallback.Register("gorm:after_error", AfterError)
deleteCallback.Register("gorm:after_delete", AfterDelete)
deleteCallback.Match(enableTransaction).Register("gorm:commit_or_rollback_transaction", CommitOrRollbackTransaction)
deleteCallback.Clauses = config.DeleteClauses
Expand All @@ -68,16 +71,19 @@
updateCallback.Register("gorm:before_update", BeforeUpdate)
updateCallback.Register("gorm:save_before_associations", SaveBeforeAssociations(false))
updateCallback.Register("gorm:update", Update(config))
updateCallback.Register("gorm:after_error", AfterError)
updateCallback.Register("gorm:save_after_associations", SaveAfterAssociations(false))
updateCallback.Register("gorm:after_update", AfterUpdate)
updateCallback.Match(enableTransaction).Register("gorm:commit_or_rollback_transaction", CommitOrRollbackTransaction)
updateCallback.Clauses = config.UpdateClauses

rowCallback := db.Callback().Row()
rowCallback.Register("gorm:row", RowQuery)
rowCallback.Register("gorm:after_error", AfterError)

Check failure on line 82 in callbacks/callbacks.go

View workflow job for this annotation

GitHub Actions / runner / golangci-lint

[golangci] reported by reviewdog 🐶 Error return value of `rowCallback.Register` is not checked (errcheck) Raw Output: callbacks/callbacks.go:82:22: Error return value of `rowCallback.Register` is not checked (errcheck) rowCallback.Register("gorm:after_error", AfterError) ^
rowCallback.Clauses = config.QueryClauses

rawCallback := db.Callback().Raw()
rawCallback.Register("gorm:raw", RawExec)
rawCallback.Register("gorm:after_error", AfterError)

Check failure on line 87 in callbacks/callbacks.go

View workflow job for this annotation

GitHub Actions / runner / golangci-lint

[golangci] reported by reviewdog 🐶 Error return value of `rawCallback.Register` is not checked (errcheck) Raw Output: callbacks/callbacks.go:87:22: Error return value of `rawCallback.Register` is not checked (errcheck) rawCallback.Register("gorm:after_error", AfterError) ^
rawCallback.Clauses = config.QueryClauses
}
24 changes: 24 additions & 0 deletions callbacks/error.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
package callbacks

import (
"gorm.io/gorm"
"reflect"
)

// AfterError after error callback executes if any error happens during main callbacks
func AfterError(db *gorm.DB) {
if db.Statement.ReflectValue.Kind() == reflect.Ptr && db.Statement.ReflectValue.IsNil() {
return
}
if db.Error != nil && db.Statement.Schema != nil && !db.Statement.SkipHooks {
callMethod(db, func(value interface{}, tx *gorm.DB) bool {
if db.Statement.Schema.AfterError {
if i, ok := value.(AfterErrorInterface); ok {
db.AddError(i.AfterError(tx))
return true
}
}
return false
})
}
}
4 changes: 4 additions & 0 deletions callbacks/interfaces.go
Original file line number Diff line number Diff line change
Expand Up @@ -37,3 +37,7 @@ type AfterDeleteInterface interface {
type AfterFindInterface interface {
AfterFind(*gorm.DB) error
}

type AfterErrorInterface interface {
AfterError(*gorm.DB) error
}
5 changes: 5 additions & 0 deletions schema/schema.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ const (
callbackTypeBeforeDelete callbackType = "BeforeDelete"
callbackTypeAfterDelete callbackType = "AfterDelete"
callbackTypeAfterFind callbackType = "AfterFind"
callbackTypeAfterError callbackType = "AfterError"
)

// ErrUnsupportedDataType unsupported data type
Expand Down Expand Up @@ -53,6 +54,7 @@ type Schema struct {
BeforeDelete, AfterDelete bool
BeforeSave, AfterSave bool
AfterFind bool
AfterError bool
err error
initialized chan struct{}
namer Namer
Expand Down Expand Up @@ -308,6 +310,7 @@ func ParseWithSpecialTableName(dest interface{}, cacheStore *sync.Map, namer Nam
callbackTypeBeforeSave, callbackTypeAfterSave,
callbackTypeBeforeDelete, callbackTypeAfterDelete,
callbackTypeAfterFind,
callbackTypeAfterError,
}
for _, cbName := range callbackTypes {
if methodValue := callBackToMethodValue(modelValue, cbName); methodValue.IsValid() {
Expand Down Expand Up @@ -397,6 +400,8 @@ func callBackToMethodValue(modelType reflect.Value, cbType callbackType) reflect
return modelType.MethodByName(string(callbackTypeAfterDelete))
case callbackTypeAfterFind:
return modelType.MethodByName(string(callbackTypeAfterFind))
case callbackTypeAfterError:
return modelType.MethodByName(string(callbackTypeAfterError))
default:
return reflect.ValueOf(nil)
}
Expand Down
39 changes: 33 additions & 6 deletions tests/hooks_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ type Product struct {
AfterSaveCallTimes int64
BeforeDeleteCallTimes int64
AfterDeleteCallTimes int64
AfterErrorCallTimes int64
}

func (s *Product) BeforeCreate(tx *gorm.DB) (err error) {
Expand Down Expand Up @@ -88,8 +89,16 @@ func (s *Product) AfterDelete(tx *gorm.DB) (err error) {
return
}

func (s *Product) AfterError(tx *gorm.DB) (err error) {
if s.Code == "after_error_error" {
err = errors.New("can't handle this error")
}
s.AfterErrorCallTimes = s.AfterErrorCallTimes + 1
return
}

func (s *Product) GetCallTimes() []int64 {
return []int64{s.BeforeCreateCallTimes, s.BeforeSaveCallTimes, s.BeforeUpdateCallTimes, s.AfterCreateCallTimes, s.AfterSaveCallTimes, s.AfterUpdateCallTimes, s.BeforeDeleteCallTimes, s.AfterDeleteCallTimes, s.AfterFindCallTimes}
return []int64{s.BeforeCreateCallTimes, s.BeforeSaveCallTimes, s.BeforeUpdateCallTimes, s.AfterCreateCallTimes, s.AfterSaveCallTimes, s.AfterUpdateCallTimes, s.BeforeDeleteCallTimes, s.AfterDeleteCallTimes, s.AfterFindCallTimes, s.AfterErrorCallTimes}
}

func TestRunCallbacks(t *testing.T) {
Expand All @@ -99,18 +108,18 @@ func TestRunCallbacks(t *testing.T) {
p := Product{Code: "unique_code", Price: 100}
DB.Save(&p)

if !reflect.DeepEqual(p.GetCallTimes(), []int64{1, 1, 0, 1, 1, 0, 0, 0, 0}) {
if !reflect.DeepEqual(p.GetCallTimes(), []int64{1, 1, 0, 1, 1, 0, 0, 0, 0, 0}) {
t.Fatalf("Callbacks should be invoked successfully, %v", p.GetCallTimes())
}

DB.Where("Code = ?", "unique_code").First(&p)
if !reflect.DeepEqual(p.GetCallTimes(), []int64{1, 1, 0, 1, 0, 0, 0, 0, 1}) {
if !reflect.DeepEqual(p.GetCallTimes(), []int64{1, 1, 0, 1, 0, 0, 0, 0, 1, 0}) {
t.Fatalf("After callbacks values are not saved, %v", p.GetCallTimes())
}

p.Price = 200
DB.Save(&p)
if !reflect.DeepEqual(p.GetCallTimes(), []int64{1, 2, 1, 1, 1, 1, 0, 0, 1}) {
if !reflect.DeepEqual(p.GetCallTimes(), []int64{1, 2, 1, 1, 1, 1, 0, 0, 1, 0}) {
t.Fatalf("After update callbacks should be invoked successfully, %v", p.GetCallTimes())
}

Expand All @@ -121,19 +130,23 @@ func TestRunCallbacks(t *testing.T) {
}

DB.Where("Code = ?", "unique_code").First(&p)
if !reflect.DeepEqual(p.GetCallTimes(), []int64{1, 2, 1, 1, 0, 0, 0, 0, 2}) {
if !reflect.DeepEqual(p.GetCallTimes(), []int64{1, 2, 1, 1, 0, 0, 0, 0, 2, 0}) {
t.Fatalf("After update callbacks values are not saved, %v", p.GetCallTimes())
}

DB.Delete(&p)
if !reflect.DeepEqual(p.GetCallTimes(), []int64{1, 2, 1, 1, 0, 0, 1, 1, 2}) {
if !reflect.DeepEqual(p.GetCallTimes(), []int64{1, 2, 1, 1, 0, 0, 1, 1, 2, 0}) {
t.Fatalf("After delete callbacks should be invoked successfully, %v", p.GetCallTimes())
}

if DB.Where("Code = ?", "unique_code").First(&p).Error == nil {
t.Fatalf("Can't find a deleted record")
}

if !reflect.DeepEqual(p.GetCallTimes(), []int64{1, 2, 1, 1, 0, 0, 1, 1, 2, 1}) {
t.Fatalf("AfterError should be called because First raises error when doesn't fint, %v", p.GetCallTimes())
}

beforeCallTimes := p.AfterFindCallTimes
if DB.Where("Code = ?", "unique_code").Find(&p).Error != nil {
t.Fatalf("Find don't raise error when record not found")
Expand All @@ -142,6 +155,12 @@ func TestRunCallbacks(t *testing.T) {
if p.AfterFindCallTimes != beforeCallTimes {
t.Fatalf("AfterFind should not be called")
}

DB.Migrator().DropTable(&Product{})
DB.Create(&p)
if !reflect.DeepEqual(p.GetCallTimes(), []int64{2, 3, 1, 1, 0, 0, 1, 1, 2, 2}) {
t.Fatalf("should call BeforeCreate, BeforeSave and AfterError, %v", p.GetCallTimes())
}
}

func TestCallbacksWithErrors(t *testing.T) {
Expand Down Expand Up @@ -208,6 +227,14 @@ func TestCallbacksWithErrors(t *testing.T) {
if err := DB.First(&Product{}, "code = ?", "after_delete_error").Error; err != nil {
t.Fatalf("Record shouldn't be deleted because of an error happened in after delete callback")
}

DB.Migrator().DropTable(&Product{})
err := DB.Create(&Product{
Code: "after_error_error",
}).Error
if err == nil || !strings.Contains(err.Error(), "can't handle this error") {
t.Fatalf("error on AfterError should be appended to the previous error, but got %v", err)
}
}

type Product2 struct {
Expand Down
Loading