From a219acca4b70a4e55765e4500491a82b3b57cbb3 Mon Sep 17 00:00:00 2001 From: qqxhb <1252905006@qq.com> Date: Fri, 18 Aug 2023 18:55:58 +0800 Subject: [PATCH 1/2] feat: rm contextconnpool method --- go.sum | 2 -- gorm.go | 7 ++++--- interfaces.go | 6 ------ prepare_stmt.go | 23 ++++++++++++++++++----- 4 files changed, 22 insertions(+), 16 deletions(-) diff --git a/go.sum b/go.sum index fb4240eb4..bd6104c9b 100644 --- a/go.sum +++ b/go.sum @@ -1,6 +1,4 @@ github.com/jinzhu/inflection v1.0.0 h1:K317FqzuhWc8YvSVlFMCCUb36O/S9MCKRDI7QkRKD/E= github.com/jinzhu/inflection v1.0.0/go.mod h1:h+uFLlag+Qp1Va5pdKtLDYj+kHp5pxUVkryuEj+Srlc= -github.com/jinzhu/now v1.1.4 h1:tHnRBy1i5F2Dh8BAFxqFzxKqqvezXrL2OW1TnX+Mlas= -github.com/jinzhu/now v1.1.4/go.mod h1:d3SSVoowX0Lcu0IBviAWJpolVfI5UJVZZ7cO71lE/z8= github.com/jinzhu/now v1.1.5 h1:/o9tlHleP7gOFmsnYNz3RGnqzefHA47wQpKrrdTIwXQ= github.com/jinzhu/now v1.1.5/go.mod h1:d3SSVoowX0Lcu0IBviAWJpolVfI5UJVZZ7cO71lE/z8= diff --git a/gorm.go b/gorm.go index 203527af3..61633c335 100644 --- a/gorm.go +++ b/gorm.go @@ -4,6 +4,7 @@ import ( "context" "database/sql" "fmt" + "reflect" "sort" "sync" "time" @@ -373,10 +374,10 @@ func (db *DB) AddError(err error) error { // DB returns `*sql.DB` func (db *DB) DB() (*sql.DB, error) { - connPool := db.ConnPool + connPool := db.Statement.ConnPool - if connector, ok := connPool.(GetDBConnectorWithContext); ok && connector != nil { - return connector.GetDBConnWithContext(db) + if tx, ok := connPool.(*sql.Tx); ok && tx != nil { + return (*sql.DB)(reflect.ValueOf(tx).Elem().FieldByName("db").UnsafePointer()), nil } if dbConnector, ok := connPool.(GetDBConnector); ok && dbConnector != nil { diff --git a/interfaces.go b/interfaces.go index 1950d7400..3bcc3d570 100644 --- a/interfaces.go +++ b/interfaces.go @@ -77,12 +77,6 @@ type GetDBConnector interface { GetDBConn() (*sql.DB, error) } -// GetDBConnectorWithContext represents SQL db connector which takes into -// account the current database context -type GetDBConnectorWithContext interface { - GetDBConnWithContext(db *DB) (*sql.DB, error) -} - // Rows rows interface type Rows interface { Columns() ([]string, error) diff --git a/prepare_stmt.go b/prepare_stmt.go index 9d98c86e0..aa944624c 100644 --- a/prepare_stmt.go +++ b/prepare_stmt.go @@ -30,15 +30,11 @@ func NewPreparedStmtDB(connPool ConnPool) *PreparedStmtDB { } } -func (db *PreparedStmtDB) GetDBConnWithContext(gormdb *DB) (*sql.DB, error) { +func (db *PreparedStmtDB) GetDBConn() (*sql.DB, error) { if sqldb, ok := db.ConnPool.(*sql.DB); ok { return sqldb, nil } - if connector, ok := db.ConnPool.(GetDBConnectorWithContext); ok && connector != nil { - return connector.GetDBConnWithContext(gormdb) - } - if dbConnector, ok := db.ConnPool.(GetDBConnector); ok && dbConnector != nil { return dbConnector.GetDBConn() } @@ -131,6 +127,19 @@ func (db *PreparedStmtDB) BeginTx(ctx context.Context, opt *sql.TxOptions) (Conn tx, err := beginner.BeginTx(ctx, opt) return &PreparedStmtTX{PreparedStmtDB: db, Tx: tx}, err } + + beginner, ok := db.ConnPool.(ConnPoolBeginner) + if !ok { + return nil, ErrInvalidTransaction + } + + connPool, err := beginner.BeginTx(ctx, opt) + if err != nil { + return nil, err + } + if tx, ok := connPool.(Tx); ok { + return &PreparedStmtTX{PreparedStmtDB: db, Tx: tx}, nil + } return nil, ErrInvalidTransaction } @@ -176,6 +185,10 @@ type PreparedStmtTX struct { PreparedStmtDB *PreparedStmtDB } +func (db *PreparedStmtTX) GetDBConn() (*sql.DB, error) { + return db.PreparedStmtDB.GetDBConn() +} + func (tx *PreparedStmtTX) Commit() error { if tx.Tx != nil && !reflect.ValueOf(tx.Tx).IsNil() { return tx.Tx.Commit() From 35d2ba9f087b9608dc2abe29121d132e424ee2f2 Mon Sep 17 00:00:00 2001 From: qqxhb <1252905006@qq.com> Date: Fri, 18 Aug 2023 19:33:38 +0800 Subject: [PATCH 2/2] feat: nil --- gorm.go | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/gorm.go b/gorm.go index 61633c335..32193870c 100644 --- a/gorm.go +++ b/gorm.go @@ -374,8 +374,10 @@ func (db *DB) AddError(err error) error { // DB returns `*sql.DB` func (db *DB) DB() (*sql.DB, error) { - connPool := db.Statement.ConnPool - + connPool := db.ConnPool + if db.Statement != nil && db.Statement.ConnPool != nil { + connPool = db.Statement.ConnPool + } if tx, ok := connPool.(*sql.Tx); ok && tx != nil { return (*sql.DB)(reflect.ValueOf(tx).Elem().FieldByName("db").UnsafePointer()), nil }