Skip to content

Commit

Permalink
feat: trace resolver mode (#89)
Browse files Browse the repository at this point in the history
  • Loading branch information
a631807682 authored Dec 4, 2022
1 parent 4a7eece commit 58a1b2a
Show file tree
Hide file tree
Showing 4 changed files with 91 additions and 12 deletions.
18 changes: 12 additions & 6 deletions dbresolver.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,10 +22,11 @@ type DBResolver struct {
}

type Config struct {
Sources []gorm.Dialector
Replicas []gorm.Dialector
Policy Policy
datas []interface{}
Sources []gorm.Dialector
Replicas []gorm.Dialector
Policy Policy
datas []interface{}
TraceResolverMode bool
}

func Register(config Config, datas ...interface{}) *DBResolver {
Expand Down Expand Up @@ -76,8 +77,9 @@ func (dr *DBResolver) compileConfig(config Config) (err error) {
var (
connPool = dr.DB.Config.ConnPool
r = resolver{
dbResolver: dr,
policy: config.Policy,
dbResolver: dr,
policy: config.Policy,
traceResolverMode: config.TraceResolverMode,
}
)

Expand Down Expand Up @@ -122,6 +124,10 @@ func (dr *DBResolver) compileConfig(config Config) (err error) {
}
}

if config.TraceResolverMode {
dr.Logger = NewResolverModeLogger(dr.Logger)
}

return nil
}

Expand Down
13 changes: 11 additions & 2 deletions dbresolver_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,12 @@ package dbresolver_test

import (
"fmt"
"os"
"testing"

"gorm.io/driver/mysql"
"gorm.io/gorm"
"gorm.io/gorm/logger"
"gorm.io/plugin/dbresolver"
)

Expand Down Expand Up @@ -50,16 +52,23 @@ func TestDBResolver(t *testing.T) {
if err != nil {
t.Fatalf("failed to connect db, got error: %v", err)
}
if debug := os.Getenv("DEBUG"); debug == "true" {
DB.Logger = DB.Logger.LogMode(logger.Info)
} else if debug == "false" {
DB.Logger = DB.Logger.LogMode(logger.Silent)
}

if err := DB.Use(dbresolver.Register(dbresolver.Config{
Sources: []gorm.Dialector{mysql.Open("gorm:gorm@tcp(localhost:9911)/gorm?charset=utf8&parseTime=True&loc=Local")},
Replicas: []gorm.Dialector{
mysql.Open("gorm:gorm@tcp(localhost:9912)/gorm?charset=utf8&parseTime=True&loc=Local"),
mysql.Open("gorm:gorm@tcp(localhost:9913)/gorm?charset=utf8&parseTime=True&loc=Local"),
},
TraceResolverMode: true,
}).Register(dbresolver.Config{
Sources: []gorm.Dialector{mysql.Open("gorm:gorm@tcp(localhost:9914)/gorm?charset=utf8&parseTime=True&loc=Local")},
Replicas: []gorm.Dialector{mysql.Open("gorm:gorm@tcp(localhost:9913)/gorm?charset=utf8&parseTime=True&loc=Local")},
Sources: []gorm.Dialector{mysql.Open("gorm:gorm@tcp(localhost:9914)/gorm?charset=utf8&parseTime=True&loc=Local")},
Replicas: []gorm.Dialector{mysql.Open("gorm:gorm@tcp(localhost:9913)/gorm?charset=utf8&parseTime=True&loc=Local")},
TraceResolverMode: true,
}, "users", &Product{}).SetMaxOpenConns(5)); err != nil {
t.Fatalf("failed to use plugin, got error: %v", err)
}
Expand Down
54 changes: 54 additions & 0 deletions logger.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
package dbresolver

import (
"context"
"fmt"
"time"

"gorm.io/gorm"
"gorm.io/gorm/logger"
)

type ResolverModeKey string
type ResolverMode string

const resolverModeKey ResolverModeKey = "dbresolver:resolver_mode_key"
const (
ResolverModeSource ResolverMode = "source"
ResolverModeReplica ResolverMode = "replica"
)

type resolverModeLogger struct {
logger.Interface
}

func (l resolverModeLogger) Trace(ctx context.Context, begin time.Time, fc func() (sql string, rowsAffected int64), err error) {
var splitFn = func() (sql string, rowsAffected int64) {
sql, rowsAffected = fc()
op := ctx.Value(resolverModeKey)
if op != nil {
sql = fmt.Sprintf("[%s] %s", op, sql)
return
}

// the situation that dbresolver does not handle
// such as transactions, or some resolvers do not enable MarkResolverMode.
return
}
l.Interface.Trace(ctx, begin, splitFn, err)
}

func NewResolverModeLogger(l logger.Interface) logger.Interface {
if _, ok := l.(resolverModeLogger); ok {
return l
}
return resolverModeLogger{
Interface: l,
}
}

func markStmtResolverMode(stmt *gorm.Statement, mode ResolverMode) {
if _, ok := stmt.Logger.(resolverModeLogger); ok {
stmt.Context = context.WithValue(stmt.Context, resolverModeKey, mode)
}
}
18 changes: 14 additions & 4 deletions resolver.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,11 @@ import (
)

type resolver struct {
sources []gorm.ConnPool
replicas []gorm.ConnPool
policy Policy
dbResolver *DBResolver
sources []gorm.ConnPool
replicas []gorm.ConnPool
policy Policy
dbResolver *DBResolver
traceResolverMode bool
}

func (r *resolver) resolve(stmt *gorm.Statement, op Operation) (connPool gorm.ConnPool) {
Expand All @@ -18,10 +19,19 @@ func (r *resolver) resolve(stmt *gorm.Statement, op Operation) (connPool gorm.Co
} else {
connPool = r.policy.Resolve(r.replicas)
}
if r.traceResolverMode {
markStmtResolverMode(stmt, ResolverModeReplica)
}
} else if len(r.sources) == 1 {
connPool = r.sources[0]
if r.traceResolverMode {
markStmtResolverMode(stmt, ResolverModeSource)
}
} else {
connPool = r.policy.Resolve(r.sources)
if r.traceResolverMode {
markStmtResolverMode(stmt, ResolverModeSource)
}
}

if stmt.DB.PrepareStmt {
Expand Down

0 comments on commit 58a1b2a

Please sign in to comment.