diff --git a/txnsql/sql.go b/txnsql/sql.go index e4caaaf..d4e62c0 100644 --- a/txnsql/sql.go +++ b/txnsql/sql.go @@ -12,8 +12,8 @@ import ( type SqlAdapter interface { txn.Adapter - // IsTx returns true if the current transaction is active. - IsTx() bool + // Returns current transaction if it exists. + Tx() *sql.Tx } // New creates a new SqlAdapter instance using the provided *sql.DB. @@ -59,6 +59,6 @@ func (a *sqlAdapter) End(_ context.Context) { } } -func (a *sqlAdapter) IsTx() bool { - return a.tx != nil +func (a *sqlAdapter) Tx() *sql.Tx { + return a.tx } diff --git a/txnsql/sql_test.go b/txnsql/sql_test.go index 6a96796..ebbfafc 100644 --- a/txnsql/sql_test.go +++ b/txnsql/sql_test.go @@ -141,23 +141,23 @@ func TestSqlAdapter_End(t *testing.T) { }) } -func TestSqlAdapter_IsTx(t *testing.T) { +func TestSqlAdapter_Tx(t *testing.T) { db, mock, err := sqlmock.New() if err != nil { t.Fatalf("an error '%s' was not expected when opening a stub database connection", err) } defer db.Close() - t.Run("IsTx true", func(t *testing.T) { + t.Run("Tx is not nil", func(t *testing.T) { adapter := &sqlAdapter{db: db} mock.ExpectBegin() adapter.Begin(context.Background()) - assert.True(t, adapter.IsTx()) + assert.True(t, adapter.Tx() != nil) }) - t.Run("IsTx false", func(t *testing.T) { + t.Run("Tx is nil", func(t *testing.T) { adapter := &sqlAdapter{db: db} - assert.False(t, adapter.IsTx()) + assert.True(t, adapter.Tx() == nil) }) }