From 2e00b2bd7db0606c229edd4a159f9db5c2817a10 Mon Sep 17 00:00:00 2001 From: guilhermefbarbosa Date: Thu, 19 Oct 2023 09:45:04 -0300 Subject: [PATCH 1/3] feat: adds AfterError callback --- callbacks/callbacks.go | 6 ++++++ callbacks/error.go | 25 +++++++++++++++++++++++++ callbacks/interfaces.go | 4 ++++ schema/schema.go | 5 +++++ 4 files changed, 40 insertions(+) create mode 100644 callbacks/error.go diff --git a/callbacks/callbacks.go b/callbacks/callbacks.go index d681aef36..a2e53e276 100644 --- a/callbacks/callbacks.go +++ b/callbacks/callbacks.go @@ -42,6 +42,7 @@ func RegisterDefaultCallbacks(db *gorm.DB, config *Config) { createCallback.Register("gorm:before_create", BeforeCreate) createCallback.Register("gorm:save_before_associations", SaveBeforeAssociations(true)) createCallback.Register("gorm:create", Create(config)) + createCallback.Register("gorm:after_error", AfterError) createCallback.Register("gorm:save_after_associations", SaveAfterAssociations(true)) createCallback.Register("gorm:after_create", AfterCreate) createCallback.Match(enableTransaction).Register("gorm:commit_or_rollback_transaction", CommitOrRollbackTransaction) @@ -49,6 +50,7 @@ func RegisterDefaultCallbacks(db *gorm.DB, config *Config) { queryCallback := db.Callback().Query() queryCallback.Register("gorm:query", Query) + queryCallback.Register("gorm:after_error", AfterError) queryCallback.Register("gorm:preload", Preload) queryCallback.Register("gorm:after_query", AfterQuery) queryCallback.Clauses = config.QueryClauses @@ -58,6 +60,7 @@ func RegisterDefaultCallbacks(db *gorm.DB, config *Config) { deleteCallback.Register("gorm:before_delete", BeforeDelete) deleteCallback.Register("gorm:delete_before_associations", DeleteBeforeAssociations) deleteCallback.Register("gorm:delete", Delete(config)) + deleteCallback.Register("gorm:after_error", AfterError) deleteCallback.Register("gorm:after_delete", AfterDelete) deleteCallback.Match(enableTransaction).Register("gorm:commit_or_rollback_transaction", CommitOrRollbackTransaction) deleteCallback.Clauses = config.DeleteClauses @@ -68,6 +71,7 @@ func RegisterDefaultCallbacks(db *gorm.DB, config *Config) { updateCallback.Register("gorm:before_update", BeforeUpdate) updateCallback.Register("gorm:save_before_associations", SaveBeforeAssociations(false)) updateCallback.Register("gorm:update", Update(config)) + updateCallback.Register("gorm:after_error", AfterError) updateCallback.Register("gorm:save_after_associations", SaveAfterAssociations(false)) updateCallback.Register("gorm:after_update", AfterUpdate) updateCallback.Match(enableTransaction).Register("gorm:commit_or_rollback_transaction", CommitOrRollbackTransaction) @@ -75,9 +79,11 @@ func RegisterDefaultCallbacks(db *gorm.DB, config *Config) { rowCallback := db.Callback().Row() rowCallback.Register("gorm:row", RowQuery) + rowCallback.Register("gorm:after_error", AfterError) rowCallback.Clauses = config.QueryClauses rawCallback := db.Callback().Raw() rawCallback.Register("gorm:raw", RawExec) + rawCallback.Register("gorm:after_error", AfterError) rawCallback.Clauses = config.QueryClauses } diff --git a/callbacks/error.go b/callbacks/error.go new file mode 100644 index 000000000..27734aad9 --- /dev/null +++ b/callbacks/error.go @@ -0,0 +1,25 @@ +package callbacks + +import ( + "gorm.io/gorm" + "reflect" +) + +// AfterError after error callback executes if any error happens during main callbacks +func AfterError(db *gorm.DB) { + if db.Statement.ReflectValue.Kind() == reflect.Ptr && db.Statement.ReflectValue.IsNil() { + return + } + if db.Error != nil && db.Statement.Schema != nil && !db.Statement.SkipHooks { + callMethod(db, func(value interface{}, tx *gorm.DB) bool { + if db.Statement.Schema.AfterError { + if i, ok := value.(AfterErrorInterface); ok { + db.AddError(i.AfterError(tx)) + return true + } + } + return false + }) + } + return +} diff --git a/callbacks/interfaces.go b/callbacks/interfaces.go index 2302470fc..baa23b066 100644 --- a/callbacks/interfaces.go +++ b/callbacks/interfaces.go @@ -37,3 +37,7 @@ type AfterDeleteInterface interface { type AfterFindInterface interface { AfterFind(*gorm.DB) error } + +type AfterErrorInterface interface { + AfterError(*gorm.DB) error +} diff --git a/schema/schema.go b/schema/schema.go index 3e7459ce7..3ca1dda2b 100644 --- a/schema/schema.go +++ b/schema/schema.go @@ -25,6 +25,7 @@ const ( callbackTypeBeforeDelete callbackType = "BeforeDelete" callbackTypeAfterDelete callbackType = "AfterDelete" callbackTypeAfterFind callbackType = "AfterFind" + callbackTypeAfterError callbackType = "AfterError" ) // ErrUnsupportedDataType unsupported data type @@ -53,6 +54,7 @@ type Schema struct { BeforeDelete, AfterDelete bool BeforeSave, AfterSave bool AfterFind bool + AfterError bool err error initialized chan struct{} namer Namer @@ -308,6 +310,7 @@ func ParseWithSpecialTableName(dest interface{}, cacheStore *sync.Map, namer Nam callbackTypeBeforeSave, callbackTypeAfterSave, callbackTypeBeforeDelete, callbackTypeAfterDelete, callbackTypeAfterFind, + callbackTypeAfterError, } for _, cbName := range callbackTypes { if methodValue := callBackToMethodValue(modelValue, cbName); methodValue.IsValid() { @@ -397,6 +400,8 @@ func callBackToMethodValue(modelType reflect.Value, cbType callbackType) reflect return modelType.MethodByName(string(callbackTypeAfterDelete)) case callbackTypeAfterFind: return modelType.MethodByName(string(callbackTypeAfterFind)) + case callbackTypeAfterError: + return modelType.MethodByName(string(callbackTypeAfterError)) default: return reflect.ValueOf(nil) } From 352d9a9abbb319c06ecd7c30c1c92a9d1e1e9a7b Mon Sep 17 00:00:00 2001 From: guilhermefbarbosa Date: Thu, 19 Oct 2023 09:45:14 -0300 Subject: [PATCH 2/3] test: adds AfterError tests --- tests/hooks_test.go | 39 +++++++++++++++++++++++++++++++++------ 1 file changed, 33 insertions(+), 6 deletions(-) diff --git a/tests/hooks_test.go b/tests/hooks_test.go index 0753dd0b1..61c20f8ab 100644 --- a/tests/hooks_test.go +++ b/tests/hooks_test.go @@ -24,6 +24,7 @@ type Product struct { AfterSaveCallTimes int64 BeforeDeleteCallTimes int64 AfterDeleteCallTimes int64 + AfterErrorCallTimes int64 } func (s *Product) BeforeCreate(tx *gorm.DB) (err error) { @@ -88,8 +89,16 @@ func (s *Product) AfterDelete(tx *gorm.DB) (err error) { return } +func (s *Product) AfterError(tx *gorm.DB) (err error) { + if s.Code == "after_error_error" { + err = errors.New("can't handle this error") + } + s.AfterErrorCallTimes = s.AfterErrorCallTimes + 1 + return +} + func (s *Product) GetCallTimes() []int64 { - return []int64{s.BeforeCreateCallTimes, s.BeforeSaveCallTimes, s.BeforeUpdateCallTimes, s.AfterCreateCallTimes, s.AfterSaveCallTimes, s.AfterUpdateCallTimes, s.BeforeDeleteCallTimes, s.AfterDeleteCallTimes, s.AfterFindCallTimes} + return []int64{s.BeforeCreateCallTimes, s.BeforeSaveCallTimes, s.BeforeUpdateCallTimes, s.AfterCreateCallTimes, s.AfterSaveCallTimes, s.AfterUpdateCallTimes, s.BeforeDeleteCallTimes, s.AfterDeleteCallTimes, s.AfterFindCallTimes, s.AfterErrorCallTimes} } func TestRunCallbacks(t *testing.T) { @@ -99,18 +108,18 @@ func TestRunCallbacks(t *testing.T) { p := Product{Code: "unique_code", Price: 100} DB.Save(&p) - if !reflect.DeepEqual(p.GetCallTimes(), []int64{1, 1, 0, 1, 1, 0, 0, 0, 0}) { + if !reflect.DeepEqual(p.GetCallTimes(), []int64{1, 1, 0, 1, 1, 0, 0, 0, 0, 0}) { t.Fatalf("Callbacks should be invoked successfully, %v", p.GetCallTimes()) } DB.Where("Code = ?", "unique_code").First(&p) - if !reflect.DeepEqual(p.GetCallTimes(), []int64{1, 1, 0, 1, 0, 0, 0, 0, 1}) { + if !reflect.DeepEqual(p.GetCallTimes(), []int64{1, 1, 0, 1, 0, 0, 0, 0, 1, 0}) { t.Fatalf("After callbacks values are not saved, %v", p.GetCallTimes()) } p.Price = 200 DB.Save(&p) - if !reflect.DeepEqual(p.GetCallTimes(), []int64{1, 2, 1, 1, 1, 1, 0, 0, 1}) { + if !reflect.DeepEqual(p.GetCallTimes(), []int64{1, 2, 1, 1, 1, 1, 0, 0, 1, 0}) { t.Fatalf("After update callbacks should be invoked successfully, %v", p.GetCallTimes()) } @@ -121,12 +130,12 @@ func TestRunCallbacks(t *testing.T) { } DB.Where("Code = ?", "unique_code").First(&p) - if !reflect.DeepEqual(p.GetCallTimes(), []int64{1, 2, 1, 1, 0, 0, 0, 0, 2}) { + if !reflect.DeepEqual(p.GetCallTimes(), []int64{1, 2, 1, 1, 0, 0, 0, 0, 2, 0}) { t.Fatalf("After update callbacks values are not saved, %v", p.GetCallTimes()) } DB.Delete(&p) - if !reflect.DeepEqual(p.GetCallTimes(), []int64{1, 2, 1, 1, 0, 0, 1, 1, 2}) { + if !reflect.DeepEqual(p.GetCallTimes(), []int64{1, 2, 1, 1, 0, 0, 1, 1, 2, 0}) { t.Fatalf("After delete callbacks should be invoked successfully, %v", p.GetCallTimes()) } @@ -134,6 +143,10 @@ func TestRunCallbacks(t *testing.T) { t.Fatalf("Can't find a deleted record") } + if !reflect.DeepEqual(p.GetCallTimes(), []int64{1, 2, 1, 1, 0, 0, 1, 1, 2, 1}) { + t.Fatalf("AfterError should be called because First raises error when doesn't fint, %v", p.GetCallTimes()) + } + beforeCallTimes := p.AfterFindCallTimes if DB.Where("Code = ?", "unique_code").Find(&p).Error != nil { t.Fatalf("Find don't raise error when record not found") @@ -142,6 +155,12 @@ func TestRunCallbacks(t *testing.T) { if p.AfterFindCallTimes != beforeCallTimes { t.Fatalf("AfterFind should not be called") } + + DB.Migrator().DropTable(&Product{}) + DB.Create(&p) + if !reflect.DeepEqual(p.GetCallTimes(), []int64{2, 3, 1, 1, 0, 0, 1, 1, 2, 2}) { + t.Fatalf("should call BeforeCreate, BeforeSave and AfterError, %v", p.GetCallTimes()) + } } func TestCallbacksWithErrors(t *testing.T) { @@ -208,6 +227,14 @@ func TestCallbacksWithErrors(t *testing.T) { if err := DB.First(&Product{}, "code = ?", "after_delete_error").Error; err != nil { t.Fatalf("Record shouldn't be deleted because of an error happened in after delete callback") } + + DB.Migrator().DropTable(&Product{}) + err := DB.Create(&Product{ + Code: "after_error_error", + }).Error + if err == nil || !strings.Contains(err.Error(), "can't handle this error") { + t.Fatalf("error on AfterError should be appended to the previous error, but got %v", err) + } } type Product2 struct { From 9599c1e38e98e9ff0a0c08f0967d2b9c9d7d936a Mon Sep 17 00:00:00 2001 From: guilhermefbarbosa Date: Thu, 19 Oct 2023 10:33:28 -0300 Subject: [PATCH 3/3] chore: fix linter errors --- callbacks/error.go | 1 - 1 file changed, 1 deletion(-) diff --git a/callbacks/error.go b/callbacks/error.go index 27734aad9..5362a7805 100644 --- a/callbacks/error.go +++ b/callbacks/error.go @@ -21,5 +21,4 @@ func AfterError(db *gorm.DB) { return false }) } - return }