Skip to content

Commit

Permalink
fix session & connection no being returned back to the pool in case t…
Browse files Browse the repository at this point in the history
…hat they cannot be recovered upon acquiring from pool.
  • Loading branch information
jxsl13 committed Mar 15, 2024
1 parent b773bf3 commit 736cc72
Show file tree
Hide file tree
Showing 2 changed files with 39 additions and 10 deletions.
25 changes: 19 additions & 6 deletions pool/connection_pool.go
Original file line number Diff line number Diff line change
Expand Up @@ -159,17 +159,25 @@ func (cp *ConnectionPool) deriveConnection(ctx context.Context, id int64, cached
}

// GetConnection only returns an error upon shutdown
func (cp *ConnectionPool) GetConnection(ctx context.Context) (*Connection, error) {
func (cp *ConnectionPool) GetConnection(ctx context.Context) (conn *Connection, err error) {
select {
case conn, ok := <-cp.connections:
if !ok {
return nil, fmt.Errorf("connection pool %w", ErrClosed)
}
if conn.IsFlagged() {
err := conn.Recover(ctx)

// recovery may fail, that's why we MUST check for errors
// and return the connection back to the pool in case that the recovery failed
// due to e.g. the pool being closed, the context being canceled, etc.
defer func() {
if err != nil {
return nil, fmt.Errorf("failed to get connection: %w", err)
cp.ReturnConnection(conn, err)
}
}()

err = conn.Recover(ctx)
if err != nil {
return nil, fmt.Errorf("failed to get connection: %w", err)
}

return conn, nil
Expand All @@ -189,11 +197,16 @@ func (cp *ConnectionPool) nextTransientID() int64 {

// GetTransientConnection may return an error when the context was cancelled before the connection could be obtained.
// Transient connections may be returned to the pool. The are closed properly upon returning.
func (cp *ConnectionPool) GetTransientConnection(ctx context.Context) (_ *Connection, err error) {
conn, err := cp.deriveConnection(ctx, cp.nextTransientID(), false)
func (cp *ConnectionPool) GetTransientConnection(ctx context.Context) (conn *Connection, err error) {
conn, err = cp.deriveConnection(ctx, cp.nextTransientID(), false)
if err == nil {
return conn, nil
}
defer func() {
if err != nil {
_ = conn.Close()
}
}()

// recover until context is closed
err = conn.Recover(ctx)
Expand Down
24 changes: 20 additions & 4 deletions pool/session_pool.go
Original file line number Diff line number Diff line change
Expand Up @@ -76,9 +76,9 @@ func newSessionPoolFromOption(pool *ConnectionPool, ctx context.Context, option
pool: pool,
autoCloseConnPool: option.AutoClosePool,

capacity: option.Capacity,
bufferCapacity: option.BufferCapacity,
confirmable: option.Confirmable,
capacity: option.Capacity,
sessions: make(chan *Session, option.Capacity),

ctx: ctx,
Expand Down Expand Up @@ -169,7 +169,7 @@ func (sp *SessionPool) Capacity() int {

// GetSession gets a pooled session.
// blocks until a session is acquired from the pool.
func (sp *SessionPool) GetSession(ctx context.Context) (*Session, error) {
func (sp *SessionPool) GetSession(ctx context.Context) (s *Session, err error) {
select {
case <-sp.catchShutdown():
return nil, sp.shutdownErr()
Expand All @@ -179,6 +179,13 @@ func (sp *SessionPool) GetSession(ctx context.Context) (*Session, error) {
if !ok {
return nil, fmt.Errorf("failed to get session: %w", ErrClosed)
}
defer func() {
// it's possible for the recovery to fail
// in that case we MUST return the session back to the pool
if err != nil {
sp.ReturnSession(session, err)
}
}()

err := session.Recover(ctx)
if err != nil {
Expand All @@ -191,14 +198,23 @@ func (sp *SessionPool) GetSession(ctx context.Context) (*Session, error) {
// GetTransientSession returns a transient session.
// This method may return an error when the context ha sbeen closed before a session could be obtained.
// A transient session creates a transient connection under the hood.
func (sp *SessionPool) GetTransientSession(ctx context.Context) (*Session, error) {
func (sp *SessionPool) GetTransientSession(ctx context.Context) (s *Session, err error) {
conn, err := sp.pool.GetTransientConnection(ctx)
if err != nil {
return nil, err
}
defer func() {
if err != nil {
sp.pool.ReturnConnection(conn, err)
}
}()

transientID := atomic.AddInt64(&sp.transientID, 1)
return sp.deriveSession(ctx, conn, int(transientID))
s, err = sp.deriveSession(ctx, conn, int(transientID))
if err != nil {
return nil, err
}
return s, nil
}

func (sp *SessionPool) deriveSession(ctx context.Context, conn *Connection, id int) (*Session, error) {
Expand Down

0 comments on commit 736cc72

Please sign in to comment.