diff --git a/sequel.go b/sequel.go index 0470a90..95b4382 100644 --- a/sequel.go +++ b/sequel.go @@ -30,9 +30,26 @@ type DB struct { } type options struct { - Clock clock.Clock - DriverName string - RebindModel bool + Clock clock.Clock + DriverName string + RebindModel bool + MaxOpenConnections int +} + +func newOptions(driverName string) *options { + return &options{ + Clock: clock.New(), + DriverName: driverName, + RebindModel: false, + MaxOpenConnections: MaxOpenConnections, + } +} + +func (o *options) apply(opts []Option) *options { + for _, fn := range opts { + fn(o) + } + return o } // Option is the type of options that can be used to modify the database. This @@ -63,23 +80,25 @@ func WithRebindModel() Option { } } +// WithMaxOpenConnections sets the maximum number of open connections to the +// database. If it is not set it will use [MaxOpenConnections] (100). +func WithMaxOpenConnections(n int) Option { + return func(o *options) { + o.MaxOpenConnections = n + } +} + // New creates a new DB. It will fail if it cannot ping it. func New(dataSourceName string, opts ...Option) (*DB, error) { - options := &options{ - Clock: clock.New(), - DriverName: "pgx/v5", - RebindModel: false, - } - for _, fn := range opts { - fn(options) - } + options := newOptions("pgx/v5").apply(opts) // Connect opens the database and verifies with a ping db, err := sqlx.Connect(options.DriverName, dataSourceName) if err != nil { return nil, fmt.Errorf("error connecting to the database: %w", err) } - db.SetMaxOpenConns(MaxOpenConnections) + db.SetMaxOpenConns(options.MaxOpenConnections) + return &DB{ db: db, clock: options.Clock, @@ -88,6 +107,27 @@ func New(dataSourceName string, opts ...Option) (*DB, error) { }, nil } +// NewDB creates a new DB wrapping the opened database handle with the given +// driverName. It will fail if it cannot ping it. +func NewDB(db *sql.DB, driverName string, opts ...Option) (*DB, error) { + options := newOptions(driverName).apply(opts) + + // Wrap an opened *sql.DB and verify the connection with a ping + dbx := sqlx.NewDb(db, options.DriverName) + if err := dbx.Ping(); err != nil { + dbx.Close() + return nil, fmt.Errorf("error connecting to the database: %w", err) + } + dbx.SetMaxOpenConns(options.MaxOpenConnections) + + return &DB{ + db: dbx, + clock: options.Clock, + doRebindModel: options.RebindModel, + driverName: options.DriverName, + }, nil +} + type dbKey struct{} // NewContext returns a new context with the given DB. @@ -148,6 +188,11 @@ func (d *DB) Driver() string { return d.driverName } +// DB returns the embedded *sql.DB. +func (d *DB) DB() *sql.DB { + return d.db.DB +} + // Rebind transforms a query from `?` to the DB driver's bind type. func (d *DB) Rebind(query string) string { return d.db.Rebind(query) diff --git a/sequel_test.go b/sequel_test.go index b67ccf1..06a9dd7 100644 --- a/sequel_test.go +++ b/sequel_test.go @@ -123,6 +123,7 @@ func TestNew(t *testing.T) { {"ok with clock", args{postgresDataSource, []Option{WithClock(clock.NewMock(time.Now()))}}, assert.NoError}, {"ok with driver", args{postgresDataSource, []Option{WithDriver("pgx/v5")}}, assert.NoError}, {"ok with rebindModel", args{postgresDataSource, []Option{WithRebindModel()}}, assert.NoError}, + {"ok with maxConnections", args{postgresDataSource, []Option{WithMaxOpenConnections(10)}}, assert.NoError}, {"fail ping", args{strings.ReplaceAll(postgresDataSource, dbUser, "foo"), nil}, assert.Error}, } for _, tt := range tests { @@ -136,6 +137,49 @@ func TestNew(t *testing.T) { } } +func TestNewDB(t *testing.T) { + testTime := time.Now() + + db, err := sql.Open("pgx/v5", postgresDataSource) + require.NoError(t, err) + closedDB, err := sql.Open("pgx/v5", postgresDataSource) + require.NoError(t, err) + require.NoError(t, closedDB.Close()) + + type args struct { + db *sql.DB + driverName string + opts []Option + } + tests := []struct { + name string + args args + want *DB + assertion assert.ErrorAssertionFunc + }{ + {"ok", args{db, "pgx/v5", nil}, &DB{ + db: sqlx.NewDb(db, "pgx/v5"), + clock: clock.New(), + doRebindModel: false, + driverName: "pgx/v5", + }, assert.NoError}, + {"ok with options", args{db, "pgx/v5", []Option{WithClock(clock.NewMock(testTime)), WithDriver("pgx"), WithRebindModel()}}, &DB{ + db: sqlx.NewDb(db, "pgx"), + clock: clock.NewMock(testTime), + doRebindModel: true, + driverName: "pgx", + }, assert.NoError}, + {"fail ping", args{closedDB, "pgx/v5", nil}, nil, assert.Error}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := NewDB(tt.args.db, tt.args.driverName, tt.args.opts...) + tt.assertion(t, err) + assert.Equal(t, tt.want, got) + }) + } +} + func TestNewContext(t *testing.T) { db, err := New(postgresDataSource) require.NoError(t, err) @@ -853,3 +897,17 @@ func TestDB_Driver(t *testing.T) { assert.Equal(t, "pgx/v5", db.Driver()) assert.NoError(t, db.Close()) } + +func TestDB_DB(t *testing.T) { + sdb, err := New(postgresDataSource) + require.NoError(t, err) + assert.Equal(t, sdb.db.DB, sdb.DB()) + assert.NoError(t, sdb.Close()) + + db, err := sql.Open("pgx/v5", postgresDataSource) + require.NoError(t, err) + sdb, err = NewDB(db, "pgx/v5") + require.NoError(t, err) + assert.Equal(t, db, sdb.DB()) + assert.NoError(t, sdb.Close()) +}