Skip to content

Commit

Permalink
♻️ Transactionの処理を簡潔にする (#32)
Browse files Browse the repository at this point in the history
  • Loading branch information
ikura-hamu authored Jan 14, 2025
1 parent 441d44b commit 3c4a091
Show file tree
Hide file tree
Showing 3 changed files with 30 additions and 42 deletions.
42 changes: 25 additions & 17 deletions server/repository/db/repository.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,43 +8,51 @@ import (
"github.com/traPtitech/piscon-portal-v2/server/repository"
)

type Repository struct {
db bob.DB
type repoDB struct {
// prevent direct access to bob.DB by beginning with an underscore
_db bob.DB
}

func NewRepository(db *sql.DB) *Repository {
return &Repository{
db: bob.NewDB(db),
type executorCtxKeyT struct{}

var executorCtxKey = executorCtxKeyT{}

func (db *repoDB) executor(ctx context.Context) bob.Executor {
if v := ctx.Value(executorCtxKey); v != nil {
exe, ok := v.(bob.Executor)
if ok {
return exe
}
}

return db._db
}

type txRepository struct {
tx bob.Tx
type Repository struct {
*repoDB
}

func newTxRepository(tx bob.Tx) *txRepository {
return &txRepository{
tx: tx,
func NewRepository(db *sql.DB) *Repository {
return &Repository{
repoDB: &repoDB{
_db: bob.NewDB(db),
},
}
}

func (r *Repository) Transaction(ctx context.Context, f func(ctx context.Context, r repository.Repository) error) error {
tx, err := r.db.BeginTx(ctx, nil)
tx, err := r._db.BeginTx(ctx, nil)
if err != nil {
return err
}
defer tx.Rollback() //nolint errcheck

txRepo := newTxRepository(tx)
ctx = context.WithValue(ctx, executorCtxKey, tx)

err = f(ctx, txRepo)
err = f(ctx, r)
if err != nil {
return err
}

return tx.Commit()
}

func (t *txRepository) Transaction(ctx context.Context, f func(ctx context.Context, r repository.Repository) error) error {
return f(ctx, t)
}
18 changes: 3 additions & 15 deletions server/repository/db/session.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,27 +14,15 @@ import (
)

func (r *Repository) FindSession(ctx context.Context, id string) (domain.Session, error) {
return findSession(ctx, r.db, id)
return findSession(ctx, r.executor(ctx), id)
}

func (r *Repository) CreateSession(ctx context.Context, session domain.Session) error {
return createSession(ctx, r.db, session)
return createSession(ctx, r.executor(ctx), session)
}

func (r *Repository) DeleteSession(ctx context.Context, id string) error {
return deleteSession(ctx, r.db, id)
}

func (t *txRepository) FindSession(ctx context.Context, id string) (domain.Session, error) {
return findSession(ctx, t.tx, id)
}

func (t *txRepository) CreateSession(ctx context.Context, session domain.Session) error {
return createSession(ctx, t.tx, session)
}

func (t *txRepository) DeleteSession(ctx context.Context, id string) error {
return deleteSession(ctx, t.tx, id)
return deleteSession(ctx, r.executor(ctx), id)
}

func findSession(ctx context.Context, executor bob.Executor, id string) (domain.Session, error) {
Expand Down
12 changes: 2 additions & 10 deletions server/repository/db/user.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,19 +14,11 @@ import (
)

func (r *Repository) FindUser(ctx context.Context, id string) (domain.User, error) {
return findUser(ctx, r.db, id)
return findUser(ctx, r.executor(ctx), id)
}

func (r *Repository) CreateUser(ctx context.Context, user domain.User) error {
return createUser(ctx, r.db, user)
}

func (t *txRepository) FindUser(ctx context.Context, id string) (domain.User, error) {
return findUser(ctx, t.tx, id)
}

func (t *txRepository) CreateUser(ctx context.Context, user domain.User) error {
return createUser(ctx, t.tx, user)
return createUser(ctx, r.executor(ctx), user)
}

func findUser(ctx context.Context, executor bob.Executor, id string) (domain.User, error) {
Expand Down

0 comments on commit 3c4a091

Please sign in to comment.