Skip to content

Commit

Permalink
Merge pull request #38 from smallstep/mariano/sql-wrap
Browse files Browse the repository at this point in the history
Allow to create a sequel.DB wrapping an sql.DB
  • Loading branch information
maraino authored Oct 11, 2024
2 parents 19d0825 + 344277a commit 3110667
Show file tree
Hide file tree
Showing 2 changed files with 115 additions and 12 deletions.
69 changes: 57 additions & 12 deletions sequel.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,9 +30,26 @@ type DB struct {
}

type options struct {
Clock clock.Clock
DriverName string
RebindModel bool
Clock clock.Clock
DriverName string
RebindModel bool
MaxOpenConnections int
}

func newOptions(driverName string) *options {
return &options{
Clock: clock.New(),
DriverName: driverName,
RebindModel: false,
MaxOpenConnections: MaxOpenConnections,
}
}

func (o *options) apply(opts []Option) *options {
for _, fn := range opts {
fn(o)
}
return o
}

// Option is the type of options that can be used to modify the database. This
Expand Down Expand Up @@ -63,23 +80,25 @@ func WithRebindModel() Option {
}
}

// WithMaxOpenConnections sets the maximum number of open connections to the
// database. If it is not set it will use [MaxOpenConnections] (100).
func WithMaxOpenConnections(n int) Option {
return func(o *options) {
o.MaxOpenConnections = n
}
}

// New creates a new DB. It will fail if it cannot ping it.
func New(dataSourceName string, opts ...Option) (*DB, error) {
options := &options{
Clock: clock.New(),
DriverName: "pgx/v5",
RebindModel: false,
}
for _, fn := range opts {
fn(options)
}
options := newOptions("pgx/v5").apply(opts)

// Connect opens the database and verifies with a ping
db, err := sqlx.Connect(options.DriverName, dataSourceName)
if err != nil {
return nil, fmt.Errorf("error connecting to the database: %w", err)
}
db.SetMaxOpenConns(MaxOpenConnections)
db.SetMaxOpenConns(options.MaxOpenConnections)

return &DB{
db: db,
clock: options.Clock,
Expand All @@ -88,6 +107,27 @@ func New(dataSourceName string, opts ...Option) (*DB, error) {
}, nil
}

// NewDB creates a new DB wrapping the opened database handle with the given
// driverName. It will fail if it cannot ping it.
func NewDB(db *sql.DB, driverName string, opts ...Option) (*DB, error) {
options := newOptions(driverName).apply(opts)

// Wrap an opened *sql.DB and verify the connection with a ping
dbx := sqlx.NewDb(db, options.DriverName)
if err := dbx.Ping(); err != nil {
dbx.Close()
return nil, fmt.Errorf("error connecting to the database: %w", err)
}
dbx.SetMaxOpenConns(options.MaxOpenConnections)

return &DB{
db: dbx,
clock: options.Clock,
doRebindModel: options.RebindModel,
driverName: options.DriverName,
}, nil
}

type dbKey struct{}

// NewContext returns a new context with the given DB.
Expand Down Expand Up @@ -148,6 +188,11 @@ func (d *DB) Driver() string {
return d.driverName
}

// DB returns the embedded *sql.DB.
func (d *DB) DB() *sql.DB {
return d.db.DB
}

// Rebind transforms a query from `?` to the DB driver's bind type.
func (d *DB) Rebind(query string) string {
return d.db.Rebind(query)
Expand Down
58 changes: 58 additions & 0 deletions sequel_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -123,6 +123,7 @@ func TestNew(t *testing.T) {
{"ok with clock", args{postgresDataSource, []Option{WithClock(clock.NewMock(time.Now()))}}, assert.NoError},
{"ok with driver", args{postgresDataSource, []Option{WithDriver("pgx/v5")}}, assert.NoError},
{"ok with rebindModel", args{postgresDataSource, []Option{WithRebindModel()}}, assert.NoError},
{"ok with maxConnections", args{postgresDataSource, []Option{WithMaxOpenConnections(10)}}, assert.NoError},
{"fail ping", args{strings.ReplaceAll(postgresDataSource, dbUser, "foo"), nil}, assert.Error},
}
for _, tt := range tests {
Expand All @@ -136,6 +137,49 @@ func TestNew(t *testing.T) {
}
}

func TestNewDB(t *testing.T) {
testTime := time.Now()

db, err := sql.Open("pgx/v5", postgresDataSource)
require.NoError(t, err)
closedDB, err := sql.Open("pgx/v5", postgresDataSource)
require.NoError(t, err)
require.NoError(t, closedDB.Close())

type args struct {
db *sql.DB
driverName string
opts []Option
}
tests := []struct {
name string
args args
want *DB
assertion assert.ErrorAssertionFunc
}{
{"ok", args{db, "pgx/v5", nil}, &DB{
db: sqlx.NewDb(db, "pgx/v5"),
clock: clock.New(),
doRebindModel: false,
driverName: "pgx/v5",
}, assert.NoError},
{"ok with options", args{db, "pgx/v5", []Option{WithClock(clock.NewMock(testTime)), WithDriver("pgx"), WithRebindModel()}}, &DB{
db: sqlx.NewDb(db, "pgx"),
clock: clock.NewMock(testTime),
doRebindModel: true,
driverName: "pgx",
}, assert.NoError},
{"fail ping", args{closedDB, "pgx/v5", nil}, nil, assert.Error},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got, err := NewDB(tt.args.db, tt.args.driverName, tt.args.opts...)
tt.assertion(t, err)
assert.Equal(t, tt.want, got)
})
}
}

func TestNewContext(t *testing.T) {
db, err := New(postgresDataSource)
require.NoError(t, err)
Expand Down Expand Up @@ -853,3 +897,17 @@ func TestDB_Driver(t *testing.T) {
assert.Equal(t, "pgx/v5", db.Driver())
assert.NoError(t, db.Close())
}

func TestDB_DB(t *testing.T) {
sdb, err := New(postgresDataSource)
require.NoError(t, err)
assert.Equal(t, sdb.db.DB, sdb.DB())
assert.NoError(t, sdb.Close())

db, err := sql.Open("pgx/v5", postgresDataSource)
require.NoError(t, err)
sdb, err = NewDB(db, "pgx/v5")
require.NoError(t, err)
assert.Equal(t, db, sdb.DB())
assert.NoError(t, sdb.Close())
}

0 comments on commit 3110667

Please sign in to comment.