Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

re-implement Lockable interface to drop requirement on driver impls. #33

Merged
merged 2 commits into from
Oct 25, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 1 addition & 2 deletions .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ jobs:

strategy:
matrix:
go-version: [1.19.x]
go-version: [1.22.x]

services:
postgres:
Expand Down Expand Up @@ -58,5 +58,4 @@ jobs:
path: ${{ env.GOPATH }}/src/github.com/${{ github.repository }}
- name: Execute Tests
run: | # we run driver and rest tests at the same time
go get -d -t ./...
make test
5 changes: 1 addition & 4 deletions .github/workflows/golangci-lint.yml
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,4 @@ jobs:
- name: golangci-lint
uses: golangci/golangci-lint-action@v2
with:
version: v1.52.1

# Optional: if set to true then the action will use pre-installed Go.
skip-go-installation: true
version: v1.61.0
2 changes: 1 addition & 1 deletion drivers/lock.go
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@ type Locker interface {
}

type Lockable interface {
DriverName() string
NewMutex(key string, logger Logger) (Locker, error)
}

// IsLockable returns whether the given instance satisfies
Expand Down
9 changes: 2 additions & 7 deletions drivers/mysql/lock.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ type Mutex struct {
// NewMutex creates a mutex with the given key name.
//
// returns error if key is empty.
func NewMutex(key string, driver drivers.Driver, logger drivers.Logger) (*Mutex, error) {
func (driver *MySQL) NewMutex(key string, logger drivers.Logger) (*Mutex, error) {
key, err := drivers.MakeLockKey(key)
if err != nil {
return nil, err
Expand All @@ -43,12 +43,7 @@ func NewMutex(key string, driver drivers.Driver, logger drivers.Logger) (*Mutex,
ctx, cancel := context.WithTimeout(context.Background(), drivers.TTL)
defer cancel()

ms, ok := driver.(*mysql)
if !ok {
return nil, errors.New("incorrect implementation of the driver")
}

conn, err := ms.db.Conn(context.Background())
conn, err := driver.db.Conn(context.Background())
if err != nil {
return nil, err
}
Expand Down
28 changes: 12 additions & 16 deletions drivers/mysql/mysql.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,13 +30,13 @@ type driverConfig struct {
closeDBonClose bool
}

type mysql struct {
type MySQL struct {
conn *sql.Conn
db *sql.DB
config *driverConfig
}

func WithInstance(dbInstance *sql.DB) (drivers.Driver, error) {
func WithInstance(dbInstance *sql.DB) (*MySQL, error) {
driverConfig := getDefaultConfig()

conn, err := dbInstance.Conn(context.Background())
Expand All @@ -48,10 +48,10 @@ func WithInstance(dbInstance *sql.DB) (drivers.Driver, error) {
return nil, err
}

return &mysql{config: driverConfig, conn: conn, db: dbInstance}, nil
return &MySQL{config: driverConfig, conn: conn, db: dbInstance}, nil
}

func Open(connURL string) (drivers.Driver, error) {
func Open(connURL string) (*MySQL, error) {
customParams, err := drivers.ExtractCustomParams(connURL, configParams)
if err != nil {
return nil, &drivers.AppError{Driver: driverName, OrigErr: err, Message: "failed to parse custom parameters from url"}
Expand Down Expand Up @@ -83,25 +83,21 @@ func Open(connURL string) (drivers.Driver, error) {

driverConfig.closeDBonClose = true

return &mysql{
return &MySQL{
conn: conn,
db: db,
config: driverConfig,
}, nil
}

func (driver *mysql) Ping() error {
func (driver *MySQL) Ping() error {
ctx, cancel := drivers.GetContext(driver.config.StatementTimeoutInSecs)
defer cancel()

return driver.conn.PingContext(ctx)
}

func (mysql) DriverName() string {
return driverName
}

func (driver *mysql) Close() error {
func (driver *MySQL) Close() error {
if driver.conn != nil {
if err := driver.conn.Close(); err != nil {
return &drivers.DatabaseError{
Expand Down Expand Up @@ -131,7 +127,7 @@ func (driver *mysql) Close() error {
return nil
}

func (driver *mysql) createSchemaTableIfNotExists() (err error) {
func (driver *MySQL) createSchemaTableIfNotExists() (err error) {
ctx, cancel := drivers.GetContext(driver.config.StatementTimeoutInSecs)
defer cancel()

Expand All @@ -149,7 +145,7 @@ func (driver *mysql) createSchemaTableIfNotExists() (err error) {
return nil
}

func (driver *mysql) Apply(migration *models.Migration, saveVersion bool) (err error) {
func (driver *MySQL) Apply(migration *models.Migration, saveVersion bool) (err error) {
query := migration.Query()
ctx, cancel := drivers.GetContext(driver.config.StatementTimeoutInSecs)
defer cancel()
Expand Down Expand Up @@ -206,7 +202,7 @@ func (driver *mysql) Apply(migration *models.Migration, saveVersion bool) (err e
return nil
}

func (driver *mysql) AppliedMigrations() (migrations []*models.Migration, err error) {
func (driver *MySQL) AppliedMigrations() (migrations []*models.Migration, err error) {
if driver.conn == nil {
return nil, &drivers.AppError{
OrigErr: errors.New("driver has no connection established"),
Expand Down Expand Up @@ -301,14 +297,14 @@ func mergeConfigWithParams(params map[string]string, config *driverConfig) (*dri
return config, nil
}

func (driver *mysql) addMigrationQuery(migration *models.Migration) string {
func (driver *MySQL) addMigrationQuery(migration *models.Migration) string {
if migration.Direction == models.Down {
return fmt.Sprintf("DELETE FROM %s WHERE (Version=%d AND NAME='%s')", driver.config.MigrationsTable, migration.Version, migration.Name)
}
return fmt.Sprintf("INSERT INTO %s (Version, Name) VALUES (%d, '%s')", driver.config.MigrationsTable, migration.Version, migration.Name)
}

func (driver *mysql) SetConfig(key string, value interface{}) error {
func (driver *MySQL) SetConfig(key string, value interface{}) error {
if driver.config != nil {
switch key {
case "StatementTimeoutInSecs":
Expand Down
36 changes: 14 additions & 22 deletions drivers/mysql/mysql_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ func (suite *MysqlTestSuite) AfterTest(_, _ string) {
}
}

func (suite *MysqlTestSuite) InitializeDriver(connURL string) (drivers.Driver, func()) {
func (suite *MysqlTestSuite) InitializeDriver(connURL string) (*MySQL, func()) {
connectedDriver, err := Open(connURL)
suite.Require().NoError(err, "should not error when connecting to database from url")
suite.Require().NotNil(connectedDriver)
Expand Down Expand Up @@ -112,49 +112,44 @@ func (suite *MysqlTestSuite) TestOpen() {
},
closeDBonClose: true, // we have created DB from DSN
}
mysqlDriver := connectedDriver.(*mysql)

suite.Assert().EqualValues(cfg, mysqlDriver.config)
suite.Assert().EqualValues(cfg, connectedDriver.config)
})

suite.T().Run("when connURL is valid can override migrations table", func(t *testing.T) {
connectedDriver, teardown := suite.InitializeDriver(testConnURL + "?x-migrations-table=test")
defer teardown()

mysqlDriver := connectedDriver.(*mysql)
suite.Assert().Equal("test", mysqlDriver.config.MigrationsTable)
suite.Assert().Equal("test", connectedDriver.config.MigrationsTable)
})

suite.T().Run("when connURL is valid can override statement timeout", func(t *testing.T) {
connectedDriver, teardown := suite.InitializeDriver(testConnURL + "?x-statement-timeout=10")
defer teardown()

mysqlDriver := connectedDriver.(*mysql)
suite.Assert().Equal(10, mysqlDriver.config.StatementTimeoutInSecs)
suite.Assert().Equal(10, connectedDriver.config.StatementTimeoutInSecs)
})

suite.T().Run("when connURL is valid can override max migration size", func(t *testing.T) {
connectedDriver, teardown := suite.InitializeDriver(testConnURL + "?x-migration-max-size=42")
defer teardown()

mysqlDriver := connectedDriver.(*mysql)
suite.Assert().Equal(42, mysqlDriver.config.MigrationMaxSize)
suite.Assert().Equal(42, connectedDriver.config.MigrationMaxSize)
})

suite.T().Run("when connURL is valid extracts database name", func(t *testing.T) {
connectedDriver, teardown := suite.InitializeDriver(testConnURL)
defer teardown()

mysqlDriver := connectedDriver.(*mysql)
suite.Assert().Equal(databaseName, mysqlDriver.config.databaseName)
suite.Assert().Equal(databaseName, connectedDriver.config.databaseName)
})
}

func (suite *MysqlTestSuite) TestCreateSchemaTableIfNotExists() {
defaultConfig := getDefaultConfig()

suite.T().Run("it errors when connection is missing", func(t *testing.T) {
driver := &mysql{}
driver := &MySQL{}

_, err := driver.AppliedMigrations()
suite.Assert().Error(err, "should error when database connection is missing")
Expand Down Expand Up @@ -372,12 +367,11 @@ func (suite *MysqlTestSuite) TestWithInstance() {
}()
suite.Assert().NoError(db.Ping(), "should not error when pinging the database")

driver, err := WithInstance(db)
mysqlDriver := driver.(*mysql)
mysqlDriver, err := WithInstance(db)
mysqlDriver.config.closeDBonClose = true
suite.Assert().NoError(err, "should not error when creating a driver from db instance")
defer func() {
err = driver.Close()
err = mysqlDriver.Close()
suite.Require().NoError(err, "should not error when closing the database connection")
}()

Expand All @@ -396,7 +390,7 @@ func (suite *MysqlTestSuite) TestLock() {

suite.T().Run("should create lock and unlock the mutex", func(t *testing.T) {
ctx := context.Background()
mx, err := NewMutex("test-lock-key", connectedDriver, logger)
mx, err := connectedDriver.NewMutex("test-lock-key", logger)
suite.Require().NoError(err, "should not error while creating the mutex")

err = mx.Lock(ctx)
Expand All @@ -409,12 +403,11 @@ func (suite *MysqlTestSuite) TestLock() {
suite.T().Run("should release the expired lock", func(t *testing.T) {
ctx := context.Background()

ms := connectedDriver.(*mysql)
query := fmt.Sprintf("INSERT INTO %s (Id, ExpireAt) VALUES (?, ?)", drivers.MutexTableName)
_, err := ms.conn.ExecContext(ctx, query, "test-lock-key", 1)
_, err := connectedDriver.conn.ExecContext(ctx, query, "test-lock-key", 1)
suite.Require().NoError(err, "should not error while manually inserting the mutex")

mx, err := NewMutex("test-lock-key", connectedDriver, logger)
mx, err := connectedDriver.NewMutex("test-lock-key", logger)
suite.Require().NoError(err, "should not error while creating the mutex")

err = mx.Lock(ctx)
Expand All @@ -430,18 +423,17 @@ func (suite *MysqlTestSuite) TestLock() {
now := time.Now()
timeout := time.After(2 * drivers.TTL) // should not wait to drop the lock for 30s

ms := connectedDriver.(*mysql)
query := fmt.Sprintf("INSERT INTO %s (Id, ExpireAt) VALUES (?, ?)", drivers.MutexTableName)
// set expiration 2 seconds later
_, err := ms.conn.ExecContext(ctx, query, "test-lock-key", now.Add(2*time.Second).Unix())
_, err := connectedDriver.conn.ExecContext(ctx, query, "test-lock-key", now.Add(2*time.Second).Unix())
suite.Require().NoError(err, "should not error while manually inserting the mutex")

done := make(chan struct{})
go func() {
defer func() {
close(done)
}()
mx, err := NewMutex("test-lock-key", connectedDriver, logger)
mx, err := connectedDriver.NewMutex("test-lock-key", logger)
suite.Require().NoError(err, "should not error while creating the mutex")

err = mx.Lock(ctx)
Expand Down
9 changes: 2 additions & 7 deletions drivers/postgres/lock.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ type Mutex struct {
// NewMutex creates a mutex with the given key name.
//
// returns error if key is empty.
func NewMutex(key string, driver drivers.Driver, logger drivers.Logger) (*Mutex, error) {
func (pg *Postgres) NewMutex(key string, logger drivers.Logger) (*Mutex, error) {
key, err := drivers.MakeLockKey(key)
if err != nil {
return nil, err
Expand All @@ -43,12 +43,7 @@ func NewMutex(key string, driver drivers.Driver, logger drivers.Logger) (*Mutex,
ctx, cancel := context.WithTimeout(context.Background(), drivers.TTL)
defer cancel()

ps, ok := driver.(*postgres)
if !ok {
return nil, errors.New("incorrect implementation of the driver")
}

conn, err := ps.db.Conn(context.Background())
conn, err := pg.db.Conn(context.Background())
if err != nil {
return nil, err
}
Expand Down
Loading
Loading