Skip to content

Commit

Permalink
fix: remove callback from callbacks if Remove() called (#6916)
Browse files Browse the repository at this point in the history
* fix: remove callback from callbacks if Remove() called

* reduce number of loops

* remove unnecessary blank line
  • Loading branch information
snackmgmg authored Mar 26, 2024
1 parent 956f7ce commit 26195e6
Show file tree
Hide file tree
Showing 2 changed files with 66 additions and 1 deletion.
19 changes: 19 additions & 0 deletions callbacks.go
Original file line number Diff line number Diff line change
Expand Up @@ -187,10 +187,18 @@ func (p *processor) Replace(name string, fn func(*DB)) error {

func (p *processor) compile() (err error) {
var callbacks []*callback
removedMap := map[string]bool{}
for _, callback := range p.callbacks {
if callback.match == nil || callback.match(p.db) {
callbacks = append(callbacks, callback)
}
if callback.remove {
removedMap[callback.name] = true
}
}

if len(removedMap) > 0 {
callbacks = removeCallbacks(callbacks, removedMap)
}
p.callbacks = callbacks

Expand Down Expand Up @@ -339,3 +347,14 @@ func sortCallbacks(cs []*callback) (fns []func(*DB), err error) {

return
}

func removeCallbacks(cs []*callback, nameMap map[string]bool) []*callback {
callbacks := make([]*callback, 0, len(cs))
for _, callback := range cs {
if nameMap[callback.name] {
continue
}
callbacks = append(callbacks, callback)
}
return callbacks
}
48 changes: 47 additions & 1 deletion tests/callbacks_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,7 @@ func TestCallbacks(t *testing.T) {
},
{
callbacks: []callback{{h: c1}, {h: c2, before: "c4", after: "c5"}, {h: c3}, {h: c4}, {h: c5}, {h: c2, remove: true}},
results: []string{"c1", "c5", "c3", "c4"},
results: []string{"c1", "c3", "c4", "c5"},
},
{
callbacks: []callback{{h: c1}, {name: "c", h: c2}, {h: c3}, {name: "c", h: c4, replace: true}},
Expand Down Expand Up @@ -206,3 +206,49 @@ func TestPluginCallbacks(t *testing.T) {
t.Errorf("callbacks tests failed, got %v", msg)
}
}

func TestCallbacksGet(t *testing.T) {
db, _ := gorm.Open(nil, nil)
createCallback := db.Callback().Create()

createCallback.Before("*").Register("c1", c1)
if cb := createCallback.Get("c1"); reflect.DeepEqual(cb, c1) {
t.Errorf("callbacks tests failed, got: %p, want: %p", cb, c1)
}

createCallback.Remove("c1")
if cb := createCallback.Get("c2"); cb != nil {
t.Errorf("callbacks test failed. got: %p, want: nil", cb)
}
}

func TestCallbacksRemove(t *testing.T) {
db, _ := gorm.Open(nil, nil)
createCallback := db.Callback().Create()

createCallback.Before("*").Register("c1", c1)
createCallback.After("*").Register("c2", c2)
createCallback.Before("c4").Register("c3", c3)
createCallback.After("c2").Register("c4", c4)

// callbacks: []string{"c1", "c3", "c4", "c2"}
createCallback.Remove("c1")
if ok, msg := assertCallbacks(createCallback, []string{"c3", "c4", "c2"}); !ok {
t.Errorf("callbacks tests failed, got %v", msg)
}

createCallback.Remove("c4")
if ok, msg := assertCallbacks(createCallback, []string{"c3", "c2"}); !ok {
t.Errorf("callbacks tests failed, got %v", msg)
}

createCallback.Remove("c2")
if ok, msg := assertCallbacks(createCallback, []string{"c3"}); !ok {
t.Errorf("callbacks tests failed, got %v", msg)
}

createCallback.Remove("c3")
if ok, msg := assertCallbacks(createCallback, []string{}); !ok {
t.Errorf("callbacks tests failed, got %v", msg)
}
}

0 comments on commit 26195e6

Please sign in to comment.