diff --git a/neo4j/bookmarks.go b/neo4j/bookmarks.go index 4a3e8ee6..9bd111fd 100644 --- a/neo4j/bookmarks.go +++ b/neo4j/bookmarks.go @@ -20,6 +20,7 @@ package neo4j import ( + "context" "github.com/neo4j/neo4j-go-driver/v5/neo4j/internal/collection" "sync" ) @@ -30,23 +31,25 @@ import ( // from raw values and BookmarksToRawValues for accessing the raw values. type Bookmarks = []string +// BookmarkManager centralizes bookmark manager supply and notification +// This API is experimental and may be changed or removed without prior notice type BookmarkManager interface { // UpdateBookmarks updates the bookmark for the specified database // previousBookmarks are the initial bookmarks of the bookmark holder (like a Session) // newBookmarks are the bookmarks that are received after completion of the bookmark holder operation (like the end of a Session) - UpdateBookmarks(database string, previousBookmarks, newBookmarks Bookmarks) + UpdateBookmarks(ctx context.Context, database string, previousBookmarks, newBookmarks Bookmarks) error // GetAllBookmarks returns all the bookmarks tracked by this bookmark manager // Note: the order of the returned bookmark slice is not guaranteed - GetAllBookmarks() Bookmarks + GetAllBookmarks(ctx context.Context) (Bookmarks, error) // GetBookmarks returns all the bookmarks associated with the specified database // Note: the order of the returned bookmark slice does not need to be deterministic - GetBookmarks(database string) Bookmarks + GetBookmarks(ctx context.Context, database string) (Bookmarks, error) // Forget removes all databases' bookmarks // Note: it is the driver user's responsibility to call this - Forget(databases ...string) + Forget(ctx context.Context, databases ...string) error } // BookmarkManagerConfig is an experimental API and may be changed or removed @@ -61,26 +64,26 @@ type BookmarkManagerConfig struct { // Hook called whenever bookmarks for a given database get updated // The hook is called with the database and the new bookmarks // Note: the order of the supplied bookmark slice is not guaranteed - BookmarkUpdateNotifier func(string, Bookmarks) + BookmarkConsumer func(ctx context.Context, database string, bookmarks Bookmarks) error } type BookmarkSupplier interface { // GetAllBookmarks returns all known bookmarks to the bookmark manager - GetAllBookmarks() Bookmarks + GetAllBookmarks(ctx context.Context) (Bookmarks, error) // GetBookmarks returns all the bookmarks of the specified database to the bookmark manager - GetBookmarks(database string) Bookmarks + GetBookmarks(ctx context.Context, database string) (Bookmarks, error) } type bookmarkManager struct { - bookmarks *sync.Map - supplier BookmarkSupplier - notifyUpdatesFn func(string, Bookmarks) + bookmarks *sync.Map + supplier BookmarkSupplier + consumer func(context.Context, string, Bookmarks) error } -func (b *bookmarkManager) UpdateBookmarks(database string, previousBookmarks, newBookmarks Bookmarks) { +func (b *bookmarkManager) UpdateBookmarks(ctx context.Context, database string, previousBookmarks, newBookmarks Bookmarks) error { if len(newBookmarks) == 0 { - return + return nil } var bookmarksToNotify Bookmarks storedNewBookmarks := collection.NewSet(newBookmarks) @@ -92,52 +95,62 @@ func (b *bookmarkManager) UpdateBookmarks(database string, previousBookmarks, ne currentBookmarks.AddAll(newBookmarks) bookmarksToNotify = currentBookmarks.Values() } - if b.notifyUpdatesFn != nil { - b.notifyUpdatesFn(database, bookmarksToNotify) + if b.consumer != nil { + return b.consumer(ctx, database, bookmarksToNotify) } + return nil } -func (b *bookmarkManager) GetAllBookmarks() Bookmarks { +func (b *bookmarkManager) GetAllBookmarks(ctx context.Context) (Bookmarks, error) { allBookmarks := collection.NewSet([]string{}) if b.supplier != nil { - allBookmarks.AddAll(b.supplier.GetAllBookmarks()) + bookmarks, err := b.supplier.GetAllBookmarks(ctx) + if err != nil { + return nil, err + } + allBookmarks.AddAll(bookmarks) } b.bookmarks.Range(func(db, rawBookmarks any) bool { bookmarks := rawBookmarks.(collection.Set[string]) allBookmarks.Union(bookmarks) return true }) - return allBookmarks.Values() + return allBookmarks.Values(), nil } -func (b *bookmarkManager) GetBookmarks(database string) Bookmarks { +func (b *bookmarkManager) GetBookmarks(ctx context.Context, database string) (Bookmarks, error) { var extraBookmarks Bookmarks if b.supplier != nil { - extraBookmarks = b.supplier.GetBookmarks(database) + bookmarks, err := b.supplier.GetBookmarks(ctx, database) + if err != nil { + return nil, err + } + extraBookmarks = bookmarks } rawBookmarks, found := b.bookmarks.Load(database) if !found { - return extraBookmarks + return extraBookmarks, nil } bookmarks := rawBookmarks.(collection.Set[string]).Copy() if extraBookmarks == nil { - return bookmarks.Values() + return bookmarks.Values(), nil } bookmarks.AddAll(extraBookmarks) - return bookmarks.Values() + return bookmarks.Values(), nil } -func (b *bookmarkManager) Forget(databases ...string) { +func (b *bookmarkManager) Forget(ctx context.Context, databases ...string) error { for _, db := range databases { b.bookmarks.Delete(db) } + return nil } func NewBookmarkManager(config BookmarkManagerConfig) BookmarkManager { return &bookmarkManager{ - bookmarks: initializeBookmarks(config.InitialBookmarks), - supplier: config.BookmarkSupplier, - notifyUpdatesFn: config.BookmarkUpdateNotifier, + bookmarks: initializeBookmarks(config.InitialBookmarks), + supplier: config.BookmarkSupplier, + consumer: config.BookmarkConsumer, } } diff --git a/neo4j/bookmarks_test.go b/neo4j/bookmarks_test.go index adbee862..41c58e9f 100644 --- a/neo4j/bookmarks_test.go +++ b/neo4j/bookmarks_test.go @@ -1,6 +1,7 @@ package neo4j_test import ( + "context" "github.com/neo4j/neo4j-go-driver/v5/neo4j" . "github.com/neo4j/neo4j-go-driver/v5/neo4j/internal/testutil" "testing" @@ -34,6 +35,8 @@ func TestCombineBookmarks(t *testing.T) { } func TestBookmarkManager(outer *testing.T) { + ctx := context.Background() + outer.Parallel() outer.Run("deduplicates initial bookmarks", func(t *testing.T) { @@ -44,11 +47,13 @@ func TestBookmarkManager(outer *testing.T) { }, }) - bookmarks1 := bookmarkManager.GetBookmarks("db1") + bookmarks1, err := bookmarkManager.GetBookmarks(ctx, "db1") + AssertNoError(t, err) expected1 := []string{"a", "b"} AssertEqualsInAnyOrder(t, bookmarks1, expected1) - bookmarks2 := bookmarkManager.GetBookmarks("db2") + bookmarks2, err := bookmarkManager.GetBookmarks(ctx, "db2") + AssertNoError(t, err) expected2 := []string{"b", "c"} AssertEqualsInAnyOrder(t, bookmarks2, expected2) }) @@ -56,7 +61,9 @@ func TestBookmarkManager(outer *testing.T) { outer.Run("gets no bookmarks by default", func(t *testing.T) { bookmarkManager := neo4j.NewBookmarkManager(neo4j.BookmarkManagerConfig{}) getBookmarks := func(db string) bool { - return bookmarkManager.GetBookmarks(db) == nil + bookmarks, err := bookmarkManager.GetBookmarks(ctx, db) + AssertNoError(t, err) + return bookmarks == nil } if err := quick.Check(getBookmarks, nil); err != nil { @@ -68,16 +75,16 @@ func TestBookmarkManager(outer *testing.T) { expectedBookmarks := neo4j.Bookmarks{"a", "b", "c"} bookmarkManager := neo4j.NewBookmarkManager(neo4j.BookmarkManagerConfig{ InitialBookmarks: map[string]neo4j.Bookmarks{"db1": {"a", "b"}}, - BookmarkSupplier: &simpleBookmarkSupplier{databaseBookmarks: func(db string) neo4j.Bookmarks { + BookmarkSupplier: &simpleBookmarkSupplier{databaseBookmarks: func(db string) (neo4j.Bookmarks, error) { if db != "db1" { t.Errorf("expected to supply bookmarks for db1, but got %s", db) } - return neo4j.Bookmarks{"b", "c"} + return neo4j.Bookmarks{"b", "c"}, nil }}, }) - actualBookmarks := bookmarkManager.GetBookmarks("db1") - + actualBookmarks, err := bookmarkManager.GetBookmarks(ctx, "db1") + AssertNoError(t, err) AssertEqualsInAnyOrder(t, actualBookmarks, expectedBookmarks) }) @@ -86,20 +93,21 @@ func TestBookmarkManager(outer *testing.T) { expectedBookmarks := []string{"a"} bookmarkManager := neo4j.NewBookmarkManager(neo4j.BookmarkManagerConfig{ InitialBookmarks: map[string]neo4j.Bookmarks{"db1": {"a"}}, - BookmarkSupplier: &simpleBookmarkSupplier{databaseBookmarks: func(db string) neo4j.Bookmarks { + BookmarkSupplier: &simpleBookmarkSupplier{databaseBookmarks: func(db string) (neo4j.Bookmarks, error) { defer func() { calls++ }() if calls == 0 { - return neo4j.Bookmarks{"b"} + return neo4j.Bookmarks{"b"}, nil } - return nil + return nil, nil }}, }) - _ = bookmarkManager.GetBookmarks("db1") - actualBookmarks := bookmarkManager.GetBookmarks("db1") - + _, err := bookmarkManager.GetBookmarks(ctx, "db1") + AssertNoError(t, err) + actualBookmarks, err := bookmarkManager.GetBookmarks(ctx, "db1") + AssertNoError(t, err) AssertEqualsInAnyOrder(t, actualBookmarks, expectedBookmarks) }) @@ -108,10 +116,12 @@ func TestBookmarkManager(outer *testing.T) { bookmarkManager := neo4j.NewBookmarkManager(neo4j.BookmarkManagerConfig{ InitialBookmarks: map[string]neo4j.Bookmarks{"db1": {"a"}}, }) - bookmarks := bookmarkManager.GetBookmarks("db1") + bookmarks, err := bookmarkManager.GetBookmarks(ctx, "db1") + AssertNoError(t, err) bookmarks[0] = "changed" - bookmarks = bookmarkManager.GetBookmarks("db1") + bookmarks, err = bookmarkManager.GetBookmarks(ctx, "db1") + AssertNoError(t, err) AssertEqualsInAnyOrder(t, bookmarks, expectedBookmarks) }) @@ -123,10 +133,12 @@ func TestBookmarkManager(outer *testing.T) { }, }) - bookmarkManager.UpdateBookmarks("db1", []string{"b", "c"}, []string{"d", "a"}) + err := bookmarkManager.UpdateBookmarks(ctx, "db1", []string{"b", "c"}, []string{"d", "a"}) + AssertNoError(t, err) expectedBookmarks := []string{"a", "d"} - actualBookmarks := bookmarkManager.GetBookmarks("db1") + actualBookmarks, err := bookmarkManager.GetBookmarks(ctx, "db1") + AssertNoError(t, err) AssertEqualsInAnyOrder(t, actualBookmarks, expectedBookmarks) }) @@ -134,18 +146,21 @@ func TestBookmarkManager(outer *testing.T) { notifyHookCalled := false expectedBookmarks := []string{"a", "d"} bookmarkManager := neo4j.NewBookmarkManager(neo4j.BookmarkManagerConfig{ - BookmarkUpdateNotifier: func(db string, bookmarks neo4j.Bookmarks) { + BookmarkConsumer: func(_ context.Context, db string, bookmarks neo4j.Bookmarks) error { notifyHookCalled = true if db != "db1" { t.Errorf("expected to receive notifications for DB db1 but received notifications for %s", db) } AssertEqualsInAnyOrder(t, bookmarks, expectedBookmarks) + return nil }, }) - bookmarkManager.UpdateBookmarks("db1", nil, []string{"d", "a"}) + err := bookmarkManager.UpdateBookmarks(ctx, "db1", nil, []string{"d", "a"}) - actualBookmarks := bookmarkManager.GetBookmarks("db1") + AssertNoError(t, err) + actualBookmarks, err := bookmarkManager.GetBookmarks(ctx, "db1") + AssertNoError(t, err) AssertEqualsInAnyOrder(t, actualBookmarks, expectedBookmarks) if !notifyHookCalled { t.Errorf("notify hook should have been called") @@ -156,14 +171,17 @@ func TestBookmarkManager(outer *testing.T) { initialBookmarks := []string{"a", "b"} bookmarkManager := neo4j.NewBookmarkManager(neo4j.BookmarkManagerConfig{ InitialBookmarks: map[string]neo4j.Bookmarks{"db1": initialBookmarks}, - BookmarkUpdateNotifier: func(db string, bookmarks neo4j.Bookmarks) { + BookmarkConsumer: func(_ context.Context, db string, bookmarks neo4j.Bookmarks) error { t.Error("I must not be called") + return nil }, }) - bookmarkManager.UpdateBookmarks("db1", initialBookmarks, nil) + err := bookmarkManager.UpdateBookmarks(ctx, "db1", initialBookmarks, nil) - actualBookmarks := bookmarkManager.GetBookmarks("db1") + AssertNoError(t, err) + actualBookmarks, err := bookmarkManager.GetBookmarks(ctx, "db1") + AssertNoError(t, err) AssertEqualsInAnyOrder(t, actualBookmarks, initialBookmarks) }) @@ -172,18 +190,21 @@ func TestBookmarkManager(outer *testing.T) { expectedBookmarks := []string{"a", "d"} bookmarkManager := neo4j.NewBookmarkManager(neo4j.BookmarkManagerConfig{ InitialBookmarks: map[string]neo4j.Bookmarks{"db1": {}}, - BookmarkUpdateNotifier: func(db string, bookmarks neo4j.Bookmarks) { + BookmarkConsumer: func(_ context.Context, db string, bookmarks neo4j.Bookmarks) error { notifyHookCalled = true if db != "db1" { t.Errorf("expected to receive notifications for DB db1 but received notifications for %s", db) } AssertEqualsInAnyOrder(t, bookmarks, expectedBookmarks) + return nil }, }) - bookmarkManager.UpdateBookmarks("db1", nil, []string{"d", "a"}) + err := bookmarkManager.UpdateBookmarks(ctx, "db1", nil, []string{"d", "a"}) - actualBookmarks := bookmarkManager.GetBookmarks("db1") + AssertNoError(t, err) + actualBookmarks, err := bookmarkManager.GetBookmarks(ctx, "db1") + AssertNoError(t, err) AssertEqualsInAnyOrder(t, actualBookmarks, expectedBookmarks) if !notifyHookCalled { t.Errorf("notify hook should have been called") @@ -195,18 +216,21 @@ func TestBookmarkManager(outer *testing.T) { expectedBookmarks := []string{"a", "d"} bookmarkManager := neo4j.NewBookmarkManager(neo4j.BookmarkManagerConfig{ InitialBookmarks: map[string]neo4j.Bookmarks{"db1": {"a", "b", "c"}}, - BookmarkUpdateNotifier: func(db string, bookmarks neo4j.Bookmarks) { + BookmarkConsumer: func(_ context.Context, db string, bookmarks neo4j.Bookmarks) error { notifyHookCalled = true if db != "db1" { t.Errorf("expected to receive notifications for DB db1 but received notifications for %s", db) } AssertEqualsInAnyOrder(t, bookmarks, expectedBookmarks) + return nil }, }) - bookmarkManager.UpdateBookmarks("db1", []string{"b", "c"}, []string{"d", "a"}) + err := bookmarkManager.UpdateBookmarks(ctx, "db1", []string{"b", "c"}, []string{"d", "a"}) - actualBookmarks := bookmarkManager.GetBookmarks("db1") + AssertNoError(t, err) + actualBookmarks, err := bookmarkManager.GetBookmarks(ctx, "db1") + AssertNoError(t, err) AssertEqualsInAnyOrder(t, actualBookmarks, expectedBookmarks) if !notifyHookCalled { t.Errorf("notify hook should have been called") @@ -222,14 +246,21 @@ func TestBookmarkManager(outer *testing.T) { }, }) - bookmarkManager.Forget("db", "par") + err := bookmarkManager.Forget(ctx, "db", "par") - allBookmarks := bookmarkManager.GetAllBookmarks() + AssertNoError(t, err) + allBookmarks, err := bookmarkManager.GetAllBookmarks(ctx) + AssertNoError(t, err) AssertEqualsInAnyOrder(t, allBookmarks, []string{"bar", "fighters"}) - AssertIntEqual(t, len(bookmarkManager.GetBookmarks("db")), 0) - AssertEqualsInAnyOrder(t, bookmarkManager.GetBookmarks("foo"), - []string{"bar", "fighters"}) - AssertIntEqual(t, len(bookmarkManager.GetBookmarks("par")), 0) + bookmarks, err := bookmarkManager.GetBookmarks(ctx, "db") + AssertNoError(t, err) + AssertIntEqual(t, len(bookmarks), 0) + bookmarks, err = bookmarkManager.GetBookmarks(ctx, "foo") + AssertNoError(t, err) + AssertEqualsInAnyOrder(t, bookmarks, []string{"bar", "fighters"}) + bookmarks, err = bookmarkManager.GetBookmarks(ctx, "par") + AssertNoError(t, err) + AssertIntEqual(t, len(bookmarks), 0) }) outer.Run("can forget untracked databases", func(t *testing.T) { @@ -239,26 +270,33 @@ func TestBookmarkManager(outer *testing.T) { }, }) - bookmarkManager.Forget("wat", "nope") + err := bookmarkManager.Forget(ctx, "wat", "nope") - allBookmarks := bookmarkManager.GetAllBookmarks() + AssertNoError(t, err) + allBookmarks, err := bookmarkManager.GetAllBookmarks(ctx) + AssertNoError(t, err) AssertEqualsInAnyOrder(t, allBookmarks, []string{"z", "cooper"}) - AssertEqualsInAnyOrder(t, bookmarkManager.GetBookmarks("db"), - []string{"z", "cooper"}) - AssertIntEqual(t, len(bookmarkManager.GetBookmarks("wat")), 0) - AssertIntEqual(t, len(bookmarkManager.GetBookmarks("nope")), 0) + bookmarks, err := bookmarkManager.GetBookmarks(ctx, "db") + AssertNoError(t, err) + AssertEqualsInAnyOrder(t, bookmarks, []string{"z", "cooper"}) + bookmarks, err = bookmarkManager.GetBookmarks(ctx, "wat") + AssertNoError(t, err) + AssertIntEqual(t, len(bookmarks), 0) + bookmarks, err = bookmarkManager.GetBookmarks(ctx, "nope") + AssertNoError(t, err) + AssertIntEqual(t, len(bookmarks), 0) }) } type simpleBookmarkSupplier struct { - allBookmarks func() neo4j.Bookmarks - databaseBookmarks func(string) neo4j.Bookmarks + allBookmarks func() (neo4j.Bookmarks, error) + databaseBookmarks func(string) (neo4j.Bookmarks, error) } -func (s *simpleBookmarkSupplier) GetAllBookmarks() neo4j.Bookmarks { +func (s *simpleBookmarkSupplier) GetAllBookmarks(context.Context) (neo4j.Bookmarks, error) { return s.allBookmarks() } -func (s *simpleBookmarkSupplier) GetBookmarks(db string) neo4j.Bookmarks { +func (s *simpleBookmarkSupplier) GetBookmarks(_ context.Context, db string) (neo4j.Bookmarks, error) { return s.databaseBookmarks(db) } diff --git a/neo4j/directrouter.go b/neo4j/directrouter.go index 091a62e2..12509502 100644 --- a/neo4j/directrouter.go +++ b/neo4j/directrouter.go @@ -38,11 +38,11 @@ func (r *directRouter) InvalidateReader(context.Context, string, string) error { return nil } -func (r *directRouter) Readers(context.Context, func() []string, string, log.BoltLogger) ([]string, error) { +func (r *directRouter) Readers(context.Context, func(context.Context) ([]string, error), string, log.BoltLogger) ([]string, error) { return []string{r.address}, nil } -func (r *directRouter) Writers(context.Context, func() []string, string, log.BoltLogger) ([]string, error) { +func (r *directRouter) Writers(context.Context, func(context.Context) ([]string, error), string, log.BoltLogger) ([]string, error) { return []string{r.address}, nil } diff --git a/neo4j/driver_with_context.go b/neo4j/driver_with_context.go index b6a65609..50406807 100644 --- a/neo4j/driver_with_context.go +++ b/neo4j/driver_with_context.go @@ -236,10 +236,10 @@ type sessionRouter interface { // note: bookmarks are lazily supplied, only when a new routing table needs to be fetched // this is needed because custom bookmark managers may provide bookmarks from external systems // they should not be called when it is not needed (e.g. when a routing table is cached) - Readers(ctx context.Context, bookmarks func() []string, database string, boltLogger log.BoltLogger) ([]string, error) + Readers(ctx context.Context, bookmarks func(context.Context) ([]string, error), database string, boltLogger log.BoltLogger) ([]string, error) // Writers returns the list of servers that can serve writes on the requested database. // note: bookmarks are lazily supplied, see Readers documentation to learn why - Writers(ctx context.Context, bookmarks func() []string, database string, boltLogger log.BoltLogger) ([]string, error) + Writers(ctx context.Context, bookmarks func(context.Context) ([]string, error), database string, boltLogger log.BoltLogger) ([]string, error) // GetNameOfDefaultDatabase returns the name of the default database for the specified user. // The correct database name is needed when requesting readers or writers. // the bookmarks are eagerly provided since this method always fetches a new routing table diff --git a/neo4j/error.go b/neo4j/error.go index 46abc8ed..4d9bf44e 100644 --- a/neo4j/error.go +++ b/neo4j/error.go @@ -26,6 +26,7 @@ import ( "github.com/neo4j/neo4j-go-driver/v5/neo4j/db" "github.com/neo4j/neo4j-go-driver/v5/neo4j/internal/bolt" "github.com/neo4j/neo4j-go-driver/v5/neo4j/internal/connector" + "github.com/neo4j/neo4j-go-driver/v5/neo4j/internal/errorutil" "github.com/neo4j/neo4j-go-driver/v5/neo4j/internal/pool" "github.com/neo4j/neo4j-go-driver/v5/neo4j/internal/retry" "github.com/neo4j/neo4j-go-driver/v5/neo4j/internal/router" @@ -180,30 +181,5 @@ type ctxCloser interface { } func deferredClose(ctx context.Context, closer ctxCloser, prevErr error) error { - return combineErrors(prevErr, closer.Close(ctx)) -} - -func combineAllErrors(errs ...error) error { - count := len(errs) - if count == 0 { - return nil - } - err := errs[0] - if count == 1 { - return err - } - for i := 1; i < count; i++ { - err = combineErrors(err, errs[i]) - } - return err -} - -func combineErrors(err1, err2 error) error { - if err2 == nil { - return err1 - } - if err1 == nil { - return err2 - } - return fmt.Errorf("error %v occurred after previous error %w", err2, err1) + return errorutil.CombineErrors(prevErr, closer.Close(ctx)) } diff --git a/neo4j/internal/errorutil/errors.go b/neo4j/internal/errorutil/errors.go new file mode 100644 index 00000000..1959cc04 --- /dev/null +++ b/neo4j/internal/errorutil/errors.go @@ -0,0 +1,45 @@ +/* + * Copyright (c) "Neo4j" + * Neo4j Sweden AB [https://neo4j.com] + * + * This file is part of Neo4j. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package errorutil + +import ( + "fmt" +) + +func CombineAllErrors(errs ...error) error { + if len(errs) == 0 { + return nil + } + result := errs[0] + for _, err := range errs[1:] { + result = CombineErrors(result, err) + } + return result +} + +func CombineErrors(err1, err2 error) error { + if err2 == nil { + return err1 + } + if err1 == nil { + return err2 + } + return fmt.Errorf("error %v occurred after previous error %w", err2, err1) +} diff --git a/neo4j/internal/errorutil/errors_test.go b/neo4j/internal/errorutil/errors_test.go new file mode 100644 index 00000000..71592391 --- /dev/null +++ b/neo4j/internal/errorutil/errors_test.go @@ -0,0 +1,134 @@ +/* + * Copyright (c) "Neo4j" + * Neo4j Sweden AB [https://neo4j.com] + * + * This file is part of Neo4j. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package errorutil_test + +import ( + "fmt" + "github.com/neo4j/neo4j-go-driver/v5/neo4j/internal/errorutil" + "reflect" + "testing" +) + +func TestCombineErrors(outer *testing.T) { + + type testCase struct { + description string + input1 error + input2 error + output error + } + + err1 := fmt.Errorf("1") + err2 := fmt.Errorf("2") + + testCases := []testCase{ + { + description: "first non-nil", + input1: err1, + input2: nil, + output: err1, + }, + { + description: "second non-nil", + input1: nil, + input2: err2, + output: err2, + }, + { + description: "all nil", + input1: nil, + input2: nil, + output: nil, + }, + { + description: "all non-nil", + input1: err1, + input2: err2, + output: fmt.Errorf("error 2 occurred after previous error %w", err1), + }, + } + + for _, testCase := range testCases { + outer.Run(testCase.description, func(t *testing.T) { + output := errorutil.CombineErrors(testCase.input1, testCase.input2) + + if !reflect.DeepEqual(testCase.output, output) { + t.Errorf("expected %v, got %v", testCase.output, output) + } + }) + } +} + +func TestCombineAllErrors(outer *testing.T) { + + type testCase struct { + description string + input []error + output error + } + + err1 := fmt.Errorf("1") + err2 := fmt.Errorf("2") + err3 := fmt.Errorf("3") + + testCases := []testCase{ + { + description: "nil slice", + input: nil, + output: nil, + }, + { + description: "empty slice - variant 1", + input: []error{}, + output: nil, + }, + { + description: "empty slice - variant 2", + input: make([]error, 0), + output: nil, + }, + { + description: "slice with single non-nil element", + input: []error{err1}, + output: err1, + }, + { + description: "slice with all three non-nil elements", + input: []error{err1, err2, err3}, + output: errorutil.CombineErrors(errorutil.CombineErrors(err1, err2), err3), + }, + { + description: "slice with 1 nil element in the middle", + input: []error{err1, nil, err3}, + output: errorutil.CombineErrors(err1, err3), + }, + } + + outer.Parallel() + for _, testCase := range testCases { + outer.Run(testCase.description, func(t *testing.T) { + output := errorutil.CombineAllErrors(testCase.input...) + + if !reflect.DeepEqual(testCase.output, output) { + t.Errorf("expected %v, got %v", testCase.output, output) + } + }) + } +} diff --git a/neo4j/internal/router/router.go b/neo4j/internal/router/router.go index f56e9d81..fd207221 100644 --- a/neo4j/internal/router/router.go +++ b/neo4j/internal/router/router.go @@ -117,7 +117,7 @@ func (r *Router) readTable(ctx context.Context, dbRouter *databaseRouter, bookma return table, nil } -func (r *Router) getOrReadTable(ctx context.Context, bookmarksFn func() []string, database string, boltLogger log.BoltLogger) (*db.RoutingTable, error) { +func (r *Router) getOrReadTable(ctx context.Context, bookmarksFn func(context.Context) ([]string, error), database string, boltLogger log.BoltLogger) (*db.RoutingTable, error) { now := r.now() if !r.dbRoutersMut.TryLock(ctx) { @@ -130,7 +130,11 @@ func (r *Router) getOrReadTable(ctx context.Context, bookmarksFn func() []string return dbRouter.table, nil } - table, err := r.readTable(ctx, dbRouter, bookmarksFn(), database, "", boltLogger) + bookmarks, err := bookmarksFn(ctx) + if err != nil { + return nil, err + } + table, err := r.readTable(ctx, dbRouter, bookmarks, database, "", boltLogger) if err != nil { return nil, err } @@ -140,7 +144,7 @@ func (r *Router) getOrReadTable(ctx context.Context, bookmarksFn func() []string return table, nil } -func (r *Router) Readers(ctx context.Context, bookmarks func() []string, database string, boltLogger log.BoltLogger) ([]string, error) { +func (r *Router) Readers(ctx context.Context, bookmarks func(context.Context) ([]string, error), database string, boltLogger log.BoltLogger) ([]string, error) { table, err := r.getOrReadTable(ctx, bookmarks, database, boltLogger) if err != nil { return nil, err @@ -170,7 +174,7 @@ func (r *Router) Readers(ctx context.Context, bookmarks func() []string, databas return table.Readers, nil } -func (r *Router) Writers(ctx context.Context, bookmarks func() []string, database string, boltLogger log.BoltLogger) ([]string, error) { +func (r *Router) Writers(ctx context.Context, bookmarks func(context.Context) ([]string, error), database string, boltLogger log.BoltLogger) ([]string, error) { table, err := r.getOrReadTable(ctx, bookmarks, database, boltLogger) if err != nil { return nil, err diff --git a/neo4j/internal/router/router_test.go b/neo4j/internal/router/router_test.go index 4c7dec7f..1c5102bf 100644 --- a/neo4j/internal/router/router_test.go +++ b/neo4j/internal/router/router_test.go @@ -399,4 +399,4 @@ func TestCleanUp(t *testing.T) { } } -func nilBookmarks() []string { return nil } +func nilBookmarks(context.Context) ([]string, error) { return nil, nil } diff --git a/neo4j/internal/testutil/routerfake.go b/neo4j/internal/testutil/routerfake.go index e2ada790..5f3ee673 100644 --- a/neo4j/internal/testutil/routerfake.go +++ b/neo4j/internal/testutil/routerfake.go @@ -28,9 +28,9 @@ type RouterFake struct { Invalidated bool InvalidatedDb string ReadersRet []string - ReadersHook func(bookmarks func() []string, database string) ([]string, error) + ReadersHook func(bookmarks func(context.Context) ([]string, error), database string) ([]string, error) WritersRet []string - WritersHook func(bookmarks func() []string, database string) ([]string, error) + WritersHook func(bookmarks func(context.Context) ([]string, error), database string) ([]string, error) Err error CleanUpHook func() GetNameOfDefaultDbHook func(user string) (string, error) @@ -55,14 +55,14 @@ func (r *RouterFake) Invalidate(ctx context.Context, database string) error { return nil } -func (r *RouterFake) Readers(ctx context.Context, bookmarksFn func() []string, database string, log log.BoltLogger) ([]string, error) { +func (r *RouterFake) Readers(ctx context.Context, bookmarksFn func(context.Context) ([]string, error), database string, log log.BoltLogger) ([]string, error) { if r.ReadersHook != nil { return r.ReadersHook(bookmarksFn, database) } return r.ReadersRet, r.Err } -func (r *RouterFake) Writers(ctx context.Context, bookmarksFn func() []string, database string, log log.BoltLogger) ([]string, error) { +func (r *RouterFake) Writers(ctx context.Context, bookmarksFn func(context.Context) ([]string, error), database string, log log.BoltLogger) ([]string, error) { if r.WritersHook != nil { return r.WritersHook(bookmarksFn, database) } diff --git a/neo4j/session_bookmarks.go b/neo4j/session_bookmarks.go index d9f70a49..4b5cfd71 100644 --- a/neo4j/session_bookmarks.go +++ b/neo4j/session_bookmarks.go @@ -19,6 +19,8 @@ package neo4j +import "context" + type sessionBookmarks struct { bookmarkManager BookmarkManager bookmarks Bookmarks @@ -44,14 +46,17 @@ func (sb *sessionBookmarks) lastBookmark() string { return bookmarks[count-1] } -func (sb *sessionBookmarks) replaceBookmarks(database string, previousBookmarks []string, newBookmark string) { +func (sb *sessionBookmarks) replaceBookmarks(ctx context.Context, database string, previousBookmarks []string, newBookmark string) error { if len(newBookmark) == 0 { - return + return nil } if sb.bookmarkManager != nil { - sb.bookmarkManager.UpdateBookmarks(database, previousBookmarks, []string{newBookmark}) + if err := sb.bookmarkManager.UpdateBookmarks(ctx, database, previousBookmarks, []string{newBookmark}); err != nil { + return err + } } sb.replaceSessionBookmarks(newBookmark) + return nil } func (sb *sessionBookmarks) replaceSessionBookmarks(newBookmark string) { @@ -61,18 +66,18 @@ func (sb *sessionBookmarks) replaceSessionBookmarks(newBookmark string) { sb.bookmarks = []string{newBookmark} } -func (sb *sessionBookmarks) bookmarksOfDatabase(db string) Bookmarks { +func (sb *sessionBookmarks) bookmarksOfDatabase(ctx context.Context, db string) (Bookmarks, error) { if sb.bookmarkManager == nil { - return nil + return nil, nil } - return sb.bookmarkManager.GetBookmarks(db) + return sb.bookmarkManager.GetBookmarks(ctx, db) } -func (sb *sessionBookmarks) allBookmarks() Bookmarks { +func (sb *sessionBookmarks) allBookmarks(ctx context.Context) (Bookmarks, error) { if sb.bookmarkManager == nil { - return nil + return nil, nil } - return sb.bookmarkManager.GetAllBookmarks() + return sb.bookmarkManager.GetAllBookmarks(ctx) } // Remove empty string bookmarks to check for "bad" callers diff --git a/neo4j/session_bookmarks_test.go b/neo4j/session_bookmarks_test.go index 3decac8c..75d1dc48 100644 --- a/neo4j/session_bookmarks_test.go +++ b/neo4j/session_bookmarks_test.go @@ -20,11 +20,14 @@ package neo4j import ( + "context" "reflect" "testing" ) func TestSessionBookmarks(outer *testing.T) { + ctx := context.Background() + outer.Parallel() outer.Run("initial set bookmarks are cleaned up", func(t *testing.T) { @@ -49,8 +52,11 @@ func TestSessionBookmarks(outer *testing.T) { "", "bookmark", "", "deutschmark", "", }) - sessionBookmarks.replaceBookmarks("db", nil, "booking mark") + err := sessionBookmarks.replaceBookmarks(ctx, "db", nil, "booking mark") + if err != nil { + t.Errorf("expected nil error, got: %v", err) + } currentBookmarks := sessionBookmarks.currentBookmarks() if !reflect.DeepEqual(currentBookmarks, []string{"booking mark"}) { t.Errorf(`expected bookmarks ["booking mark"], got %v`, currentBookmarks) @@ -64,8 +70,11 @@ func TestSessionBookmarks(outer *testing.T) { outer.Run("does not replace set bookmarks when new bookmark is empty", func(t *testing.T) { sessionBookmarks := newSessionBookmarks(nil, []string{"book marking"}) - sessionBookmarks.replaceBookmarks("db", nil, "") + err := sessionBookmarks.replaceBookmarks(ctx, "db", nil, "") + if err != nil { + t.Errorf("expected nil error, got: %v", err) + } currentBookmarks := sessionBookmarks.currentBookmarks() if !reflect.DeepEqual(currentBookmarks, []string{"book marking"}) { t.Errorf(`expected bookmarks ["book marking"], got %v`, currentBookmarks) @@ -91,9 +100,12 @@ func TestSessionBookmarks(outer *testing.T) { bookmarkManager := &fakeBookmarkManager{} sessionBookmarks := newSessionBookmarks(bookmarkManager, nil) - sessionBookmarks.replaceBookmarks("dbz", []string{"b1", "b2"}, "b3") + err := sessionBookmarks.replaceBookmarks(ctx, "dbz", []string{"b1", "b2"}, "b3") - if !bookmarkManager.called(1, "UpdateBookmarks", "dbz", []string{"b1", "b2"}, []string{"b3"}) { + if err != nil { + t.Errorf("expected nil error, got: %v", err) + } + if !bookmarkManager.called(1, "UpdateBookmarks", ctx, "dbz", []string{"b1", "b2"}, []string{"b3"}) { t.Errorf("Expected UpdateBookmarks to be called once but was not") } }) @@ -102,9 +114,9 @@ func TestSessionBookmarks(outer *testing.T) { bookmarkManager := &fakeBookmarkManager{} sessionBookmarks := newSessionBookmarks(bookmarkManager, nil) - sessionBookmarks.bookmarksOfDatabase("dbz") + _, _ = sessionBookmarks.bookmarksOfDatabase(ctx, "dbz") - if !bookmarkManager.called(1, "GetBookmarks", "dbz") { + if !bookmarkManager.called(1, "GetBookmarks", ctx, "dbz") { t.Errorf("Expected GetBookmarks to be called once but was not") } }) @@ -113,9 +125,9 @@ func TestSessionBookmarks(outer *testing.T) { bookmarkManager := &fakeBookmarkManager{} sessionBookmarks := newSessionBookmarks(bookmarkManager, nil) - sessionBookmarks.allBookmarks() + _, _ = sessionBookmarks.allBookmarks(ctx) - if !bookmarkManager.called(1, "GetAllBookmarks") { + if !bookmarkManager.called(1, "GetAllBookmarks", ctx) { t.Errorf("Expected GetBookmarks with the provided arguments to be called once but was not") } }) @@ -131,29 +143,32 @@ type fakeBookmarkManager struct { recordedCalls []invocation } -func (f *fakeBookmarkManager) UpdateBookmarks(database string, previousBookmarks, newBookmarks Bookmarks) { +func (f *fakeBookmarkManager) UpdateBookmarks(ctx context.Context, database string, previousBookmarks, newBookmarks Bookmarks) error { f.recordedCalls = append(f.recordedCalls, invocation{ function: "UpdateBookmarks", - arguments: []any{database, previousBookmarks, newBookmarks}, + arguments: []any{ctx, database, previousBookmarks, newBookmarks}, }) + return nil } -func (f *fakeBookmarkManager) GetBookmarks(database string) Bookmarks { +func (f *fakeBookmarkManager) GetBookmarks(ctx context.Context, database string) (Bookmarks, error) { f.recordedCalls = append(f.recordedCalls, invocation{ function: "GetBookmarks", - arguments: []any{database}, + arguments: []any{ctx, database}, }) - return nil + return nil, nil } -func (f *fakeBookmarkManager) GetAllBookmarks() Bookmarks { +func (f *fakeBookmarkManager) GetAllBookmarks(ctx context.Context) (Bookmarks, error) { f.recordedCalls = append(f.recordedCalls, invocation{ - function: "GetAllBookmarks", + function: "GetAllBookmarks", + arguments: []any{ctx}, }) - return nil + return nil, nil } -func (f *fakeBookmarkManager) Forget(...string) { +func (f *fakeBookmarkManager) Forget(context.Context, ...string) error { + return nil } func (f *fakeBookmarkManager) called(times int, function string, args ...any) bool { diff --git a/neo4j/session_with_context.go b/neo4j/session_with_context.go index 1e58dd54..c1bac9d2 100644 --- a/neo4j/session_with_context.go +++ b/neo4j/session_with_context.go @@ -24,6 +24,7 @@ import ( "fmt" "github.com/neo4j/neo4j-go-driver/v5/neo4j/internal/collection" idb "github.com/neo4j/neo4j-go-driver/v5/neo4j/internal/db" + "github.com/neo4j/neo4j-go-driver/v5/neo4j/internal/errorutil" "github.com/neo4j/neo4j-go-driver/v5/neo4j/internal/pool" "math" "time" @@ -238,7 +239,11 @@ func (s *sessionWithContext) BeginTransaction(ctx context.Context, configurers . } // Begin transaction - beginBookmarks := s.transactionBookmarks() + beginBookmarks, err := s.transactionBookmarks(ctx) + if err != nil { + s.pool.Return(ctx, conn) + return nil, wrapError(err) + } txHandle, err := conn.TxBegin(ctx, idb.TxConfig{ Mode: s.defaultMode, @@ -257,10 +262,11 @@ func (s *sessionWithContext) BeginTransaction(ctx context.Context, configurers . conn: conn, fetchSize: s.fetchSize, txHandle: txHandle, - onClosed: func() { + onClosed: func(tx *explicitTransaction) { // On transaction closed (rolled back or committed) - s.retrieveBookmarks(conn, beginBookmarks) - s.pool.Return(ctx, conn) + bookmarkErr := s.retrieveBookmarks(ctx, conn, beginBookmarks) + poolErr := s.pool.Return(ctx, conn) + tx.err = errorutil.CombineAllErrors(tx.err, bookmarkErr, poolErr) s.explicitTx = nil }, } @@ -367,7 +373,11 @@ func (s *sessionWithContext) executeTransactionFunction( // handle transaction function panic as well defer s.pool.Return(ctx, conn) - beginBookmarks := s.transactionBookmarks() + beginBookmarks, err := s.transactionBookmarks(ctx) + if err != nil { + state.OnFailure(ctx, conn, err, false) + return true, nil + } txHandle, err := conn.TxBegin(ctx, idb.TxConfig{ Mode: mode, @@ -398,7 +408,11 @@ func (s *sessionWithContext) executeTransactionFunction( return true, nil } - s.retrieveBookmarks(conn, beginBookmarks) + // transaction has been committed so let's ignore (ie just log) the error + if err = s.retrieveBookmarks(ctx, conn, beginBookmarks); err != nil { + s.log.Warnf(log.Session, s.logId, "could not retrieve bookmarks after successful commit: %s\n"+ + "the results of this transaction may not be visible to subsequent operations", err.Error()) + } return false, x } @@ -452,16 +466,16 @@ func (s *sessionWithContext) getConnection(ctx context.Context, mode idb.AccessM return conn, nil } -func (s *sessionWithContext) retrieveBookmarks(conn idb.Connection, sentBookmarks Bookmarks) { +func (s *sessionWithContext) retrieveBookmarks(ctx context.Context, conn idb.Connection, sentBookmarks Bookmarks) error { if conn == nil { - return + return nil } bookmark, bookmarkDatabase := conn.Bookmark() db := s.databaseName if bookmarkDatabase != "" { db = bookmarkDatabase } - s.bookmarks.replaceBookmarks(db, sentBookmarks, bookmark) + return s.bookmarks.replaceBookmarks(ctx, db, sentBookmarks, bookmark) } func (s *sessionWithContext) retrieveSessionBookmarks(conn idb.Connection) { @@ -498,7 +512,11 @@ func (s *sessionWithContext) Run(ctx context.Context, return nil, err } - runBookmarks := s.transactionBookmarks() + runBookmarks, err := s.transactionBookmarks(ctx) + if err != nil { + s.pool.Return(ctx, conn) + return nil, wrapError(err) + } stream, err := conn.Run( ctx, idb.Command{ @@ -521,7 +539,10 @@ func (s *sessionWithContext) Run(ctx context.Context, s.autocommitTx = &autocommitTransaction{ conn: conn, res: newResultWithContext(conn, stream, cypher, params, func() { - s.retrieveBookmarks(conn, runBookmarks) + if err := s.retrieveBookmarks(ctx, conn, runBookmarks); err != nil { + s.log.Warnf(log.Session, s.logId, "could not retrieve bookmarks after result consumption: %s\n"+ + "the result of the initiating auto-commit transaction may not be visible to subsequent operations", err.Error()) + } }), onClosed: func() { s.pool.Return(ctx, conn) @@ -551,7 +572,7 @@ func (s *sessionWithContext) Close(ctx context.Context) error { go func() { routerErrChan <- s.router.CleanUp(ctx) }() - return combineAllErrors(txErr, <-poolErrChan, <-routerErrChan) + return errorutil.CombineAllErrors(txErr, <-poolErrChan, <-routerErrChan) } func (s *sessionWithContext) legacy() Session { @@ -584,7 +605,10 @@ func (s *sessionWithContext) resolveHomeDatabase(ctx context.Context) error { } // the actual database may not be known yet so the session initial bookmarks might actually belong to system - bookmarks := s.routingBookmarks() + bookmarks, err := s.routingBookmarks(ctx) + if err != nil { + return err + } defaultDb, err := s.router.GetNameOfDefaultDatabase(ctx, bookmarks, s.impersonatedUser, s.boltLogger) if err != nil { return err @@ -595,18 +619,25 @@ func (s *sessionWithContext) resolveHomeDatabase(ctx context.Context) error { return nil } -func (s *sessionWithContext) transactionBookmarks() Bookmarks { - result := collection.NewSet(s.bookmarks.allBookmarks()) +func (s *sessionWithContext) transactionBookmarks(ctx context.Context) (Bookmarks, error) { + bookmarks, err := s.bookmarks.allBookmarks(ctx) + if err != nil { + return nil, err + } + result := collection.NewSet(bookmarks) result.AddAll(s.bookmarks.currentBookmarks()) - return result.Values() + return result.Values(), nil } -func (s *sessionWithContext) routingBookmarks() Bookmarks { - systemBookmarks := s.bookmarks.bookmarksOfDatabase("system") +func (s *sessionWithContext) routingBookmarks(ctx context.Context) (Bookmarks, error) { + systemBookmarks, err := s.bookmarks.bookmarksOfDatabase(ctx, "system") + if err != nil { + return nil, err + } sessionBookmarks := s.bookmarks.currentBookmarks() bookmarks := collection.NewSet(systemBookmarks) bookmarks.AddAll(sessionBookmarks) - return bookmarks.Values() + return bookmarks.Values(), nil } type erroredSessionWithContext struct { diff --git a/neo4j/session_with_context_test.go b/neo4j/session_with_context_test.go index 5c60364d..c07d3c1b 100644 --- a/neo4j/session_with_context_test.go +++ b/neo4j/session_with_context_test.go @@ -154,7 +154,7 @@ func TestSession(outer *testing.T) { numDefaultDbLookups++ return mydb, nil } - router.WritersHook = func(bookmarks func() []string, database string) ([]string, error) { + router.WritersHook = func(_ func(context.Context) ([]string, error), database string) ([]string, error) { AssertStringEqual(t, mydb, database) return []string{"aserver"}, nil } @@ -337,7 +337,7 @@ func TestSession(outer *testing.T) { numDefaultDbLookups++ return mydb, nil } - router.ReadersHook = func(bookmarks func() []string, database string) ([]string, error) { + router.ReadersHook = func(_ func(context.Context) ([]string, error), database string) ([]string, error) { AssertStringEqual(t, mydb, database) return []string{"aserver"}, nil } @@ -505,7 +505,7 @@ func TestSession(outer *testing.T) { numDefaultDbLookups++ return mydb, nil } - router.ReadersHook = func(bookmarks func() []string, database string) ([]string, error) { + router.ReadersHook = func(_ func(context.Context) ([]string, error), database string) ([]string, error) { AssertStringEqual(t, mydb, database) return []string{"aserver"}, nil } @@ -621,7 +621,7 @@ func TestSession(outer *testing.T) { router, _, session := createSession() defer session.Close(ctx) expectedErr := fmt.Errorf("server retrieval err") - router.ReadersHook = func(func() []string, string) ([]string, error) { + router.ReadersHook = func(func(context.Context) ([]string, error), string) ([]string, error) { return nil, expectedErr } diff --git a/neo4j/transaction_with_context.go b/neo4j/transaction_with_context.go index e9940d06..09705405 100644 --- a/neo4j/transaction_with_context.go +++ b/neo4j/transaction_with_context.go @@ -57,7 +57,7 @@ type explicitTransaction struct { done bool runFailed bool err error - onClosed func() + onClosed func(*explicitTransaction) } func (tx *explicitTransaction) Run(ctx context.Context, cypher string, @@ -66,7 +66,7 @@ func (tx *explicitTransaction) Run(ctx context.Context, cypher string, if err != nil { tx.err = err tx.runFailed = true - tx.onClosed() + tx.onClosed(tx) return nil, wrapError(tx.err) } // no result consumption hook here since bookmarks are sent after commit, not after pulling results @@ -83,7 +83,7 @@ func (tx *explicitTransaction) Commit(ctx context.Context) error { } tx.err = tx.conn.TxCommit(ctx, tx.txHandle) tx.done = true - tx.onClosed() + tx.onClosed(tx) return wrapError(tx.err) } @@ -110,7 +110,7 @@ func (tx *explicitTransaction) Rollback(ctx context.Context) error { tx.err = tx.conn.TxRollback(ctx, tx.txHandle) } tx.done = true - tx.onClosed() + tx.onClosed(tx) return wrapError(tx.err) } diff --git a/testkit-backend/backend.go b/testkit-backend/backend.go index 6d0dfd89..d43ddc7f 100644 --- a/testkit-backend/backend.go +++ b/testkit-backend/backend.go @@ -1107,18 +1107,18 @@ func (b *backend) bookmarkManagerConfig(bookmarkManagerId string, supplierRegistered := config["bookmarksSupplierRegistered"] if supplierRegistered != nil && supplierRegistered.(bool) { result.BookmarkSupplier = &testkitBookmarkSupplier{ - supplierFn: b.supplyBookmarks(bookmarkManagerId), + supplier: b.supplyBookmarks(bookmarkManagerId), } } consumerRegistered := config["bookmarksConsumerRegistered"] if consumerRegistered != nil && consumerRegistered.(bool) { - result.BookmarkUpdateNotifier = b.consumeBookmarks(bookmarkManagerId) + result.BookmarkConsumer = b.consumeBookmarks(bookmarkManagerId) } return result } -func (b *backend) supplyBookmarks(bookmarkManagerId string) func(...string) neo4j.Bookmarks { - return func(databases ...string) neo4j.Bookmarks { +func (b *backend) supplyBookmarks(bookmarkManagerId string) func(...string) (neo4j.Bookmarks, error) { + return func(databases ...string) (neo4j.Bookmarks, error) { if len(databases) > 1 { panic("at most 1 database should be specified") } @@ -1130,13 +1130,13 @@ func (b *backend) supplyBookmarks(bookmarkManagerId string) func(...string) neo4 b.writeResponse("BookmarksSupplierRequest", msg) for { b.process() - return b.suppliedBookmarks[id] + return b.suppliedBookmarks[id], nil } } } -func (b *backend) consumeBookmarks(bookmarkManagerId string) func(string, neo4j.Bookmarks) { - return func(database string, bookmarks neo4j.Bookmarks) { +func (b *backend) consumeBookmarks(bookmarkManagerId string) func(context.Context, string, neo4j.Bookmarks) error { + return func(_ context.Context, database string, bookmarks neo4j.Bookmarks) error { id := b.nextId() b.writeResponse("BookmarksConsumerRequest", map[string]any{ "id": id, @@ -1147,22 +1147,22 @@ func (b *backend) consumeBookmarks(bookmarkManagerId string) func(string, neo4j. for { b.process() if _, found := b.consumedBookmarks[id]; found { - return + return nil } } } } type testkitBookmarkSupplier struct { - supplierFn func(...string) neo4j.Bookmarks + supplier func(...string) (neo4j.Bookmarks, error) } -func (t *testkitBookmarkSupplier) GetAllBookmarks() neo4j.Bookmarks { - return t.supplierFn() +func (t *testkitBookmarkSupplier) GetAllBookmarks(context.Context) (neo4j.Bookmarks, error) { + return t.supplier() } -func (t *testkitBookmarkSupplier) GetBookmarks(database string) neo4j.Bookmarks { - return t.supplierFn(database) +func (t *testkitBookmarkSupplier) GetBookmarks(_ context.Context, database string) (neo4j.Bookmarks, error) { + return t.supplier(database) } func convertInitialBookmarks(bookmarks map[string]any) map[string]neo4j.Bookmarks {