Skip to content

Commit

Permalink
fix memory leaks in PrepareStatementDB (#7142)
Browse files Browse the repository at this point in the history
* fix memory leaks in PrepareStatementDB

* Fix CR:
1) Fix potential Segmentation Fault in Reset function
2) Setting db.Stmts to nil map when Close to avoid further using

* Add Test:
1) TestPreparedStmtConcurrentReset
2) TestPreparedStmtConcurrentClose

* Fix test, create new connection to keep away from other tests

---------

Co-authored-by: Zehui Chen <[email protected]>
  • Loading branch information
ivila and Zehui Chen authored Aug 22, 2024
1 parent 4a50b36 commit 0dbfda5
Show file tree
Hide file tree
Showing 2 changed files with 175 additions and 16 deletions.
44 changes: 28 additions & 16 deletions prepare_stmt.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,18 +17,16 @@ type Stmt struct {
}

type PreparedStmtDB struct {
Stmts map[string]*Stmt
PreparedSQL []string
Mux *sync.RWMutex
Stmts map[string]*Stmt
Mux *sync.RWMutex
ConnPool
}

func NewPreparedStmtDB(connPool ConnPool) *PreparedStmtDB {
return &PreparedStmtDB{
ConnPool: connPool,
Stmts: make(map[string]*Stmt),
Mux: &sync.RWMutex{},
PreparedSQL: make([]string, 0, 100),
ConnPool: connPool,
Stmts: make(map[string]*Stmt),
Mux: &sync.RWMutex{},
}
}

Expand All @@ -48,22 +46,32 @@ func (db *PreparedStmtDB) Close() {
db.Mux.Lock()
defer db.Mux.Unlock()

for _, query := range db.PreparedSQL {
if stmt, ok := db.Stmts[query]; ok {
delete(db.Stmts, query)
go stmt.Close()
}
for _, stmt := range db.Stmts {
go func(s *Stmt) {
// make sure the stmt must finish preparation first
<-s.prepared
if s.Stmt != nil {
_ = s.Close()
}
}(stmt)
}
// setting db.Stmts to nil to avoid further using
db.Stmts = nil
}

func (sdb *PreparedStmtDB) Reset() {
sdb.Mux.Lock()
defer sdb.Mux.Unlock()

for _, stmt := range sdb.Stmts {
go stmt.Close()
go func(s *Stmt) {
// make sure the stmt must finish preparation first
<-s.prepared
if s.Stmt != nil {
_ = s.Close()
}
}(stmt)
}
sdb.PreparedSQL = make([]string, 0, 100)
sdb.Stmts = make(map[string]*Stmt)
}

Expand Down Expand Up @@ -93,7 +101,12 @@ func (db *PreparedStmtDB) prepare(ctx context.Context, conn ConnPool, isTransact

return *stmt, nil
}

// check db.Stmts first to avoid Segmentation Fault(setting value to nil map)
// which cause by calling Close and executing SQL concurrently
if db.Stmts == nil {
db.Mux.Unlock()
return Stmt{}, ErrInvalidDB
}
// cache preparing stmt first
cacheStmt := Stmt{Transaction: isTransaction, prepared: make(chan struct{})}
db.Stmts[query] = &cacheStmt
Expand All @@ -118,7 +131,6 @@ func (db *PreparedStmtDB) prepare(ctx context.Context, conn ConnPool, isTransact

db.Mux.Lock()
cacheStmt.Stmt = stmt
db.PreparedSQL = append(db.PreparedSQL, query)
db.Mux.Unlock()

return cacheStmt, nil
Expand Down
147 changes: 147 additions & 0 deletions tests/prepared_stmt_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"context"
"errors"
"sync"
"sync/atomic"
"testing"
"time"

Expand Down Expand Up @@ -167,3 +168,149 @@ func TestPreparedStmtReset(t *testing.T) {
t.Fatalf("prepared stmt should be empty")
}
}

func isUsingClosedConnError(err error) bool {
// https://github.com/golang/go/blob/e705a2d16e4ece77e08e80c168382cdb02890f5b/src/database/sql/sql.go#L2717
return err.Error() == "sql: statement is closed"
}

// TestPreparedStmtConcurrentReset test calling reset and executing SQL concurrently
// this test making sure that the gorm would not get a Segmentation Fault, and the only error cause by this is using a closed Stmt
func TestPreparedStmtConcurrentReset(t *testing.T) {
name := "prepared_stmt_concurrent_reset"
user := *GetUser(name, Config{})
createTx := DB.Session(&gorm.Session{}).Create(&user)
if createTx.Error != nil {
t.Fatalf("failed to prepare record due to %s, test cannot be continue", createTx.Error)
}

// create a new connection to keep away from other tests
tx, err := OpenTestConnection(&gorm.Config{PrepareStmt: true})
if err != nil {
t.Fatalf("failed to open test connection due to %s", err)
}
pdb, ok := tx.ConnPool.(*gorm.PreparedStmtDB)
if !ok {
t.Fatalf("should assign PreparedStatement Manager back to database when using PrepareStmt mode")
}

loopCount := 100
var wg sync.WaitGroup
var unexpectedError bool
writerFinish := make(chan struct{})

wg.Add(1)
go func(id uint) {
defer wg.Done()
defer close(writerFinish)

for j := 0; j < loopCount; j++ {
var tmp User
err := tx.Session(&gorm.Session{}).First(&tmp, id).Error
if err == nil || isUsingClosedConnError(err) {
continue
}
t.Errorf("failed to read user of id %d due to %s, there should not be error", id, err)
unexpectedError = true
break
}
}(user.ID)

wg.Add(1)
go func() {
defer wg.Done()
<-writerFinish
pdb.Reset()
}()

wg.Wait()

if unexpectedError {
t.Fatalf("should is a unexpected error")
}
}

// TestPreparedStmtConcurrentClose test calling close and executing SQL concurrently
// for example: one goroutine found error and just close the database, and others are executing SQL
// this test making sure that the gorm would not get a Segmentation Fault,
// and the only error cause by this is using a closed Stmt or gorm.ErrInvalidDB
// and all of the goroutine must got gorm.ErrInvalidDB after database close
func TestPreparedStmtConcurrentClose(t *testing.T) {
name := "prepared_stmt_concurrent_close"
user := *GetUser(name, Config{})
createTx := DB.Session(&gorm.Session{}).Create(&user)
if createTx.Error != nil {
t.Fatalf("failed to prepare record due to %s, test cannot be continue", createTx.Error)
}

// create a new connection to keep away from other tests
tx, err := OpenTestConnection(&gorm.Config{PrepareStmt: true})
if err != nil {
t.Fatalf("failed to open test connection due to %s", err)
}
pdb, ok := tx.ConnPool.(*gorm.PreparedStmtDB)
if !ok {
t.Fatalf("should assign PreparedStatement Manager back to database when using PrepareStmt mode")
}

loopCount := 100
var wg sync.WaitGroup
var lastErr error
closeValid := make(chan struct{}, loopCount)
closeStartIdx := loopCount / 2 // close the database at the middle of the execution
var lastRunIndex int
var closeFinishedAt int64

wg.Add(1)
go func(id uint) {
defer wg.Done()
defer close(closeValid)
for lastRunIndex = 1; lastRunIndex <= loopCount; lastRunIndex++ {
if lastRunIndex == closeStartIdx {
closeValid <- struct{}{}
}
var tmp User
now := time.Now().UnixNano()
err := tx.Session(&gorm.Session{}).First(&tmp, id).Error
if err == nil {
closeFinishedAt := atomic.LoadInt64(&closeFinishedAt)
if (closeFinishedAt != 0) && (now > closeFinishedAt) {
lastErr = errors.New("must got error after database closed")
break
}
continue
}
lastErr = err
break
}
}(user.ID)

wg.Add(1)
go func() {
defer wg.Done()
for range closeValid {
for i := 0; i < loopCount; i++ {
pdb.Close() // the Close method must can be call multiple times
atomic.CompareAndSwapInt64(&closeFinishedAt, 0, time.Now().UnixNano())
}
}
}()

wg.Wait()
var tmp User
err = tx.Session(&gorm.Session{}).First(&tmp, user.ID).Error
if err != gorm.ErrInvalidDB {
t.Fatalf("must got a gorm.ErrInvalidDB while execution after db close, got %+v instead", err)
}

// must be error
if lastErr != gorm.ErrInvalidDB && !isUsingClosedConnError(lastErr) {
t.Fatalf("exp error gorm.ErrInvalidDB, got %+v instead", lastErr)
}
if lastRunIndex >= loopCount || lastRunIndex < closeStartIdx {
t.Fatalf("exp loop times between (closeStartIdx %d <=) and (< loopCount %d), got %d instead", closeStartIdx, loopCount, lastRunIndex)
}
if pdb.Stmts != nil {
t.Fatalf("stmts must be nil")
}
}

0 comments on commit 0dbfda5

Please sign in to comment.