diff --git a/util/errors/error.go b/util/errors/error.go index 648b0611..bf9af6eb 100644 --- a/util/errors/error.go +++ b/util/errors/error.go @@ -67,6 +67,8 @@ func NewSkip(reason any, skip int) error { } // Errorf is a shortcut for `errors.New(fmt.Errorf("format", args))`. +// Be careful when using this, this will result in losing the callers of +// the original error if one of the `args` is of type `*errors.Error`. func Errorf(format string, args ...any) error { return NewSkip(fmt.Errorf(format, args...), 3) } diff --git a/util/session/session.go b/util/session/session.go index 3cc9f6b9..52db45b4 100644 --- a/util/session/session.go +++ b/util/session/session.go @@ -100,7 +100,7 @@ type dbKey struct{} // The Gorm DB associated with this session is injected into the context as a value so `session.DB()` // can be used to retrieve it. func (s Gorm) Transaction(ctx context.Context, f func(context.Context) error) error { - tx := s.db.WithContext(ctx).Begin(s.TxOptions) + tx := DB(ctx, s.db).WithContext(ctx).Begin(s.TxOptions) if tx.Error != nil { return errors.New(tx.Error) } diff --git a/util/session/session_test.go b/util/session/session_test.go index 0b68b542..5088a560 100644 --- a/util/session/session_test.go +++ b/util/session/session_test.go @@ -8,6 +8,7 @@ import ( "github.com/stretchr/testify/assert" "gorm.io/gorm" + "gorm.io/gorm/clause" "gorm.io/gorm/utils/tests" "goyave.dev/goyave/v5/config" "goyave.dev/goyave/v5/database" @@ -122,12 +123,14 @@ func TestGormSession(t *testing.T) { ctx := context.WithValue(context.Background(), testKey{}, "testvalue") tx, err := session.Begin(ctx) + tx.(Gorm).db.Statement.Clauses["testclause"] = clause.Clause{} // Use this to check the nested db is based on the parent DB assert.NoError(t, err) assert.NotNil(t, tx) subtx, err := session.Begin(tx.Context()) assert.NoError(t, err) assert.Equal(t, "testvalue", subtx.(Gorm).db.Statement.Context.Value(testKey{})) // Parent context is kept + assert.Contains(t, subtx.(Gorm).db.Statement.Clauses, "testclause") // Parent DB is used }) t.Run("Transaction", func(t *testing.T) { @@ -155,6 +158,30 @@ func TestGormSession(t *testing.T) { assert.False(t, committer.rolledback) }) + t.Run("Nested_Transaction", func(t *testing.T) { + db, err := database.NewFromDialector(cfg, nil, tests.DummyDialector{}) + if !assert.NoError(t, err) { + return + } + committer := &testCommitter{} + db.Statement.ConnPool = committer + session := GORM(db, nil) + + ctx := context.WithValue(context.Background(), testKey{}, "testvalue") + tx, err := session.Begin(ctx) + tx.(Gorm).db.Statement.Clauses["testclause"] = clause.Clause{} // Use this to check the nested db is based on the parent DB + assert.NoError(t, err) + assert.NotNil(t, tx) + + err = session.Transaction(tx.Context(), func(ctx context.Context) error { + db := DB(ctx, nil) + assert.NotNil(t, db) + assert.Contains(t, db.Statement.Clauses, "testclause") // Parent DB is used + return nil + }) + assert.NoError(t, err) + }) + t.Run("TransactionError", func(t *testing.T) { db, err := database.NewFromDialector(cfg, nil, tests.DummyDialector{}) if !assert.NoError(t, err) {