From d8467678992a9698aed124d5432faf8c5895773e Mon Sep 17 00:00:00 2001 From: Rancho Date: Fri, 30 Dec 2022 17:10:44 +0800 Subject: [PATCH 1/2] code improve --- internal/adapter/repository/mysql.go | 17 ++++++++--------- 1 file changed, 8 insertions(+), 9 deletions(-) diff --git a/internal/adapter/repository/mysql.go b/internal/adapter/repository/mysql.go index 6a112ab..f41f592 100644 --- a/internal/adapter/repository/mysql.go +++ b/internal/adapter/repository/mysql.go @@ -3,7 +3,7 @@ package repository import ( "context" "fmt" - buitin_log "log" + buitinLog "log" "os" "time" @@ -25,7 +25,7 @@ import ( func buildGormConfig() *gorm.Config { logger := gormLogger.New( - buitin_log.New(os.Stdout, "\r\n", buitin_log.LstdFlags), // io writer + buitinLog.New(os.Stdout, "\r\n", buitinLog.LstdFlags), gormLogger.Config{ SlowThreshold: time.Second, // Slow SQL threshold LogLevel: gormLogger.Info, // Log level @@ -60,10 +60,10 @@ func (c *MySQL) Close(ctx context.Context) { if sqlDB != nil { err := sqlDB.Close() if err != nil { - log.SugaredLogger.Errorf("close mysql client fail. err: %v", err) + log.SugaredLogger.Errorf("close MySQL fail. err: %v", err) } } - log.Logger.Info("mysql client closed") + log.Logger.Info("MySQL closed") } func (c *MySQL) MockClient() (*gorm.DB, sqlmock.Sqlmock) { @@ -71,13 +71,13 @@ func (c *MySQL) MockClient() (*gorm.DB, sqlmock.Sqlmock) { if err != nil { panic("mock MySQL fail, err: " + err.Error()) } - dialector := driver.New(driver.Config{ + dialect := driver.New(driver.Config{ Conn: sqlDB, DriverName: "mysql-mock", SkipInitializeWithVersion: true, }) - c.db, err = gorm.Open(dialector, buildGormConfig()) + c.db, err = gorm.Open(dialect, buildGormConfig()) return c.db, mock } @@ -93,7 +93,7 @@ func openGormDB() (*gorm.DB, error) { config.Config.MySQL.ParseTime, config.Config.MySQL.TimeZone, ) - dialector = driver.New(driver.Config{ + dialect = driver.New(driver.Config{ DSN: dsn, DriverName: "mysql", DefaultStringSize: 255, @@ -111,8 +111,7 @@ func openGormDB() (*gorm.DB, error) { }) ) - db, err := gorm.Open(dialector, buildGormConfig()) - + db, err := gorm.Open(dialect, buildGormConfig()) if err != nil { return nil, err } From 92d0a0548220750e3dad439d905ab12b41084c23 Mon Sep 17 00:00:00 2001 From: Rancho Date: Fri, 30 Dec 2022 18:13:45 +0800 Subject: [PATCH 2/2] remove transaction tmpl type --- internal/adapter/repository/mysql.go | 49 ++++++------ .../repository/mysql/entity/example.go | 60 ++++---------- .../repository/mysql/entity/example_test.go | 12 +-- .../repository/mysql/entity/main_test.go | 2 +- .../adapter/repository/mysql/transaction.go | 78 ------------------- internal/adapter/repository/redis.go | 11 ++- internal/adapter/repository/repository.go | 61 +++++++++++---- .../adapter/repository/repository_test.go | 40 +++++++++- 8 files changed, 134 insertions(+), 179 deletions(-) delete mode 100644 internal/adapter/repository/mysql/transaction.go diff --git a/internal/adapter/repository/mysql.go b/internal/adapter/repository/mysql.go index f41f592..4a5a93e 100644 --- a/internal/adapter/repository/mysql.go +++ b/internal/adapter/repository/mysql.go @@ -3,7 +3,7 @@ package repository import ( "context" "fmt" - buitinLog "log" + builtinLog "log" "os" "time" @@ -23,30 +23,18 @@ import ( * @date 2021/12/21 */ -func buildGormConfig() *gorm.Config { - logger := gormLogger.New( - buitinLog.New(os.Stdout, "\r\n", buitinLog.LstdFlags), - gormLogger.Config{ - SlowThreshold: time.Second, // Slow SQL threshold - LogLevel: gormLogger.Info, // Log level - IgnoreRecordNotFoundError: false, // Ignore ErrRecordNotFound error for logger - Colorful: true, // Disable color - }, - ) - // logger := zapgorm2.New(log.Logger) - // logger.SetAsDefault() - // logger.LogMode(gormLogger.Info) - - return &gorm.Config{ - NamingStrategy: schema.NamingStrategy{SingularTable: true}, - Logger: logger, - } -} - type MySQL struct { db *gorm.DB } +func NewMySQLClient() *MySQL { + db, err := openGormDB() + if err != nil { + panic(err) + } + return &MySQL{db: db} +} + func (c *MySQL) GetDB(ctx context.Context) *gorm.DB { return c.db.WithContext(ctx) } @@ -128,10 +116,19 @@ func openGormDB() (*gorm.DB, error) { return db, nil } -func NewMySQLClient() *MySQL { - db, err := openGormDB() - if err != nil { - panic(err) +func buildGormConfig() *gorm.Config { + logger := gormLogger.New( + builtinLog.New(os.Stdout, "\r\n", builtinLog.LstdFlags), + gormLogger.Config{ + SlowThreshold: time.Second, // Slow SQL threshold + LogLevel: gormLogger.Info, // Log level + IgnoreRecordNotFoundError: false, // Ignore ErrRecordNotFound error for logger + Colorful: true, // Disable color + }, + ) + + return &gorm.Config{ + NamingStrategy: schema.NamingStrategy{SingularTable: true}, + Logger: logger, } - return &MySQL{db: db} } diff --git a/internal/adapter/repository/mysql/entity/example.go b/internal/adapter/repository/mysql/entity/example.go index 5d519bc..107716b 100644 --- a/internal/adapter/repository/mysql/entity/example.go +++ b/internal/adapter/repository/mysql/entity/example.go @@ -10,7 +10,6 @@ import ( "gorm.io/gorm" "go-hexagonal/internal/adapter/repository" - "go-hexagonal/internal/adapter/repository/mysql" "go-hexagonal/internal/domain/model" "go-hexagonal/internal/domain/repo" ) @@ -25,14 +24,13 @@ func NewExample() *Example { } type Example struct { - mysql.TransactionImpl `structs:"-"` // inheritance mysql transaction implement - Id int `json:"id" gorm:"primarykey" structs:",omitempty,underline"` - Name string `json:"name" structs:",omitempty,underline"` - Alias string `json:"alias" structs:",omitempty,underline"` - CreatedAt time.Time `json:"created_at" structs:",omitempty,underline"` - UpdatedAt time.Time `json:"updated_at" structs:",omitempty,underline"` - DeletedAt gorm.DeletedAt `json:"deleted_at" structs:",omitempty,underline"` - ChangeMap map[string]interface{} `json:"-" gorm:"-" structs:"-"` + Id int `json:"id" gorm:"primarykey" structs:",omitempty,underline"` + Name string `json:"name" structs:",omitempty,underline"` + Alias string `json:"alias" structs:",omitempty,underline"` + CreatedAt time.Time `json:"created_at" structs:",omitempty,underline"` + UpdatedAt time.Time `json:"updated_at" structs:",omitempty,underline"` + DeletedAt gorm.DeletedAt `json:"deleted_at" structs:",omitempty,underline"` + ChangeMap map[string]interface{} `json:"-" gorm:"-" structs:"-"` } func (e Example) TableName() string { @@ -52,13 +50,7 @@ func (e *Example) Create(ctx context.Context, tr *repository.Transaction, model return nil, errors.Wrap(err, "copier fail") } - // conn db - db, err := e.ConnDB(ctx, tr) - if err != nil { - return nil, err - } - - // handle sql + db := tr.Conn(ctx) err = db.Create(entity).Error if err != nil { return nil, err @@ -75,13 +67,7 @@ func (e *Example) Create(ctx context.Context, tr *repository.Transaction, model func (e *Example) Delete(ctx context.Context, tr *repository.Transaction, id int) (err error) { entity := &Example{} - // conn db - db, err := e.ConnDB(ctx, tr) - if err != nil { - return err - } - - // handle sql + db := tr.Conn(ctx) err = db.Delete(entity, id).Error // hard delete // err := tx.Unscoped().Delete(entity, Id).Error @@ -97,14 +83,8 @@ func (e *Example) Update(ctx context.Context, tr *repository.Transaction, model entity.ChangeMap = structs.Map(entity) entity.ChangeMap["updated_at"] = time.Now() - // conn db - db, err := e.ConnDB(ctx, tr) - if err != nil { - return err - } - - // handle sql - db.Table(entity.TableName()).Where("id = ? AND deleted_at IS NULL", entity.Id).Updates(entity.ChangeMap) + db := tr.Conn(ctx) + db = db.Table(entity.TableName()).Where("id = ? AND deleted_at IS NULL", entity.Id).Updates(entity.ChangeMap) return db.Error } @@ -112,14 +92,8 @@ func (e *Example) Update(ctx context.Context, tr *repository.Transaction, model func (e *Example) GetByID(ctx context.Context, tr *repository.Transaction, id int) (domain *model.Example, err error) { entity := &Example{} - // conn db - db, err := e.ConnDB(ctx, tr) - if err != nil { - return nil, err - } - - // handle sql - db.Table(entity.TableName()).Find(entity, id) + db := tr.Conn(ctx) + db = db.Table(entity.TableName()).Find(entity, id) if db.Error != nil { return nil, err @@ -136,13 +110,7 @@ func (e *Example) GetByID(ctx context.Context, tr *repository.Transaction, id in func (e *Example) FindByName(ctx context.Context, tr *repository.Transaction, name string) (model *model.Example, err error) { entity := &Example{} - // conn db - db, err := e.ConnDB(ctx, tr) - if err != nil { - return nil, err - } - - // handle sql + db := tr.Conn(ctx) db.Table(entity.TableName()).Where("name = ?", name).Last(entity) if db.Error != nil { return nil, err diff --git a/internal/adapter/repository/mysql/entity/example_test.go b/internal/adapter/repository/mysql/entity/example_test.go index 7f69eec..8c876e8 100644 --- a/internal/adapter/repository/mysql/entity/example_test.go +++ b/internal/adapter/repository/mysql/entity/example_test.go @@ -10,7 +10,6 @@ import ( "go-hexagonal/api/dto" "go-hexagonal/internal/adapter/repository" - "go-hexagonal/internal/adapter/repository/mysql" "go-hexagonal/internal/domain/model" ) @@ -48,10 +47,13 @@ func TestExample_Create(t *testing.T) { Name: "rancho", Alias: "cooper", } - tr := mysql.NewTransaction(ctx, &sql.TxOptions{ - Isolation: sql.LevelReadUncommitted, - ReadOnly: false, - }) + tr := repository.NewTransaction(ctx, + repository.MySQLStore, + &sql.TxOptions{ + Isolation: sql.LevelReadUncommitted, + ReadOnly: false, + }, + ) example, err := exampleRepo.Create(ctx, tr, e) assert.NoError(t, err) assert.NotEmpty(t, example.Id) diff --git a/internal/adapter/repository/mysql/entity/main_test.go b/internal/adapter/repository/mysql/entity/main_test.go index 2db754a..f380c91 100644 --- a/internal/adapter/repository/mysql/entity/main_test.go +++ b/internal/adapter/repository/mysql/entity/main_test.go @@ -22,7 +22,7 @@ func TestMain(m *testing.M) { config.Init() log.Init() - repository.Clients.MySQL = repository.NewMySQLClient() + repository.Init(repository.WithMySQL()) _ = repository.Clients.MySQL.GetDB(ctx).AutoMigrate(&Example{}) m.Run() } diff --git a/internal/adapter/repository/mysql/transaction.go b/internal/adapter/repository/mysql/transaction.go deleted file mode 100644 index e62f626..0000000 --- a/internal/adapter/repository/mysql/transaction.go +++ /dev/null @@ -1,78 +0,0 @@ -package mysql - -import ( - "context" - "database/sql" - - "github.com/pkg/errors" - "gorm.io/gorm" - - "go-hexagonal/internal/adapter/repository" -) - -/** - * @author Rancho - * @date 2022/12/30 - */ - -type TransactionImpl struct { -} - -func NewTransaction(ctx context.Context, opt *sql.TxOptions) *repository.Transaction { - session := repository.Clients.MySQL.GetDB(ctx) - if opt != nil { - session = session.Begin(opt) - } - - return &repository.Transaction{ - Session: session, - TxOpt: opt, - } -} - -func (t TransactionImpl) ConnDB(ctx context.Context, tr *repository.Transaction) (db *gorm.DB, err error) { - if tr == nil { - // init transaction with default session - tr = &repository.Transaction{Session: repository.Clients.MySQL.GetDB(ctx)} - } - if tr.Session == nil { - // begin new with TxOpt - tr.Session = repository.Clients.MySQL.GetDB(ctx).Begin(tr.TxOpt) - } - - return tr.Session, err -} - -func (t TransactionImpl) Begin(ctx context.Context, tr *repository.Transaction) { - if tr == nil { - tr = &repository.Transaction{} - } - if tr.Session == nil { - tr.Session = repository.Clients.MySQL.GetDB(ctx).Begin(tr.TxOpt) - } - -} - -func (t TransactionImpl) Commit(tr *repository.Transaction) error { - if tr == nil { - return errors.New("Commit with nil tr") - } - if tr.Session == nil { - return errors.New("Commit with nil tr.Session") - } - - return tr.Session.Commit().Error -} - -func (t TransactionImpl) Rollback(tr *repository.Transaction) error { - if tr == nil { - return errors.New("Rollback with nil tr") - } - if tr.Session == nil { - return errors.New("Rollback with nil tr.Session") - } - - return tr.Session.Rollback().Error -} - -var _ repository.ITransaction = &TransactionImpl{} diff --git a/internal/adapter/repository/redis.go b/internal/adapter/repository/redis.go index f3ab061..45eb53f 100644 --- a/internal/adapter/repository/redis.go +++ b/internal/adapter/repository/redis.go @@ -24,6 +24,10 @@ type Redis struct { db *redis.Client } +func NewRedisClient() *Redis { + return &Redis{db: newRedisConn()} +} + func (r *Redis) GetClient() *redis.Client { return r.db } @@ -36,8 +40,7 @@ func (r *Redis) Close(ctx context.Context) { log.Logger.Info("redis client closed") } -func (r *Redis) MockClient() redismock.ClusterClientMock { - // FIXME unverified +func (r *Redis) MockClient() redismock.ClientMock { db, mock := redismock.NewClientMock() r.db = db return mock @@ -54,7 +57,3 @@ func newRedisConn() *redis.Client { IdleTimeout: time.Duration(config.Config.Redis.IdleTimeout) * time.Second, }) } - -func NewRedisClient() *Redis { - return &Redis{db: newRedisConn()} -} diff --git a/internal/adapter/repository/repository.go b/internal/adapter/repository/repository.go index 0955480..7d931a4 100644 --- a/internal/adapter/repository/repository.go +++ b/internal/adapter/repository/repository.go @@ -15,14 +15,16 @@ var Clients = &clients{} type Transaction struct { Session *gorm.DB TxOpt *sql.TxOptions - // Tx *sql.Tx } -type ITransaction interface { - Begin(context.Context, *Transaction) - Commit(*Transaction) error - Rollback(*Transaction) error -} +type StoreType string + +const ( + MySQLStore StoreType = "MySQL" + RedisStore StoreType = "Redis" + MongoStore StoreType = "Mongo" + PostgreSQLStore StoreType = "PostgreSQL" +) type clients struct { MySQL *MySQL @@ -31,6 +33,39 @@ type clients struct { type Option func(*clients) +func (tr *Transaction) Conn(ctx context.Context) *gorm.DB { + if tr == nil { + // init transaction with default session + return Clients.MySQL.GetDB(ctx) + } + if tr.Session == nil { + // begin new with TxOpt + tr.Session = Clients.MySQL.GetDB(ctx).Begin(tr.TxOpt) + } + + return tr.Session +} + +func NewTransaction(ctx context.Context, store StoreType, opt *sql.TxOptions) *Transaction { + tr := &Transaction{TxOpt: opt} + + if store == MySQLStore { + session := Clients.MySQL.GetDB(ctx) + if opt != nil { + session = session.Begin(opt) + } + tr.Session = session + } else if store == RedisStore { + // TODO + } else if store == MongoStore { + // TODO + } else if store == PostgreSQLStore { + // TODO + } + + return tr +} + func (c *clients) close(ctx context.Context) { if c.MySQL != nil { c.MySQL.Close(ctx) @@ -43,11 +78,10 @@ func (c *clients) close(ctx context.Context) { func WithMySQL() Option { return func(c *clients) { if c.MySQL == nil { - if config.Config.MySQL != nil { - c.MySQL = NewMySQLClient() - } else { - panic("init repository fail, MySQL config is empty") + if config.Config.MySQL == nil { + panic("repository init fail, MySQL config is empty") } + c.MySQL = NewMySQLClient() } } } @@ -55,11 +89,10 @@ func WithMySQL() Option { func WithRedis() Option { return func(c *clients) { if c.Redis == nil { - if config.Config.Redis != nil { - c.Redis = NewRedisClient() - } else { - panic("init repository fail, Redis config is empty") + if config.Config.Redis == nil { + panic("repository init fail, Redis config is empty") } + c.Redis = NewRedisClient() } } } diff --git a/internal/adapter/repository/repository_test.go b/internal/adapter/repository/repository_test.go index 7c1bad6..fe4a1ff 100644 --- a/internal/adapter/repository/repository_test.go +++ b/internal/adapter/repository/repository_test.go @@ -2,8 +2,11 @@ package repository import ( "context" + "database/sql" "testing" + "github.com/stretchr/testify/assert" + "go-hexagonal/config" "go-hexagonal/util/log" ) @@ -20,9 +23,40 @@ func TestNewRepository(t *testing.T) { log.Init() Init(WithMySQL(), WithRedis()) - // assert.Nil(t, err) - // assert.NotNil(t, model) Close(ctx) - // redis +} + +func TestTransaction_Conn(t *testing.T) { + config.Init() + log.Init() + + Init(WithMySQL(), WithRedis()) + + t.Run("nil caller", func(t *testing.T) { + var tr *Transaction + db := tr.Conn(ctx) + assert.NotNil(t, db) + }) + + t.Run("with empty session", func(t *testing.T) { + tr := NewTransaction(ctx, + MySQLStore, + nil, + ) + tr.Session = nil + db := tr.Conn(ctx) + assert.NotNil(t, db) + }) + t.Run("with opt", func(t *testing.T) { + tr := NewTransaction(ctx, + MySQLStore, + &sql.TxOptions{ + Isolation: sql.LevelReadUncommitted, + ReadOnly: false, + }, + ) + db := tr.Conn(ctx) + assert.NotNil(t, db) + }) }