diff --git a/README.md b/README.md index 9175436..11774c8 100644 --- a/README.md +++ b/README.md @@ -17,6 +17,7 @@ import ( ) func main() { + // Direct Initialization: // Initialize a Redis adapter and use it in a Casbin enforcer: a, _ := redisadapter.NewAdapter("tcp", "127.0.0.1:6379") // Your Redis network and address. @@ -29,6 +30,15 @@ func main() { // Use the following if you use Redis connections pool // pool := &redis.Pool{} // a, err := redisadapter.NewAdapterWithPool(pool) + + // Initialization with different user options: + // Use the following if you use Redis with passowrd like "123": + // a, err := redisadapter.NewAdapterWithOption(redisadapter.WithNetwork("tcp"), redisadapter.WithAddress("127.0.0.1:6379"), redisadapter.WithPassword("123")) + + // Use the following if you use Redis with username, password, and TLS option: + // var clientTLSConfig tls.Config + // ... + // a, err := redisadapter.NewAdapterWithOption(redisadapter.WithNetwork("tcp"), redisadapter.WithAddress("127.0.0.1:6379"), redisadapter.WithUsername("testAccount"), redisadapter.WithPassword("123456"), redisadapter.WithTls(&clientTLSConfig)) e := casbin.NewEnforcer("examples/rbac_model.conf", a) diff --git a/adapter.go b/adapter.go index 307e970..aa36e00 100644 --- a/adapter.go +++ b/adapter.go @@ -16,6 +16,7 @@ package redisadapter import ( "bytes" + "crypto/tls" "encoding/json" "errors" "fmt" @@ -46,6 +47,7 @@ type Adapter struct { key string username string password string + tlsConfig *tls.Config conn redis.Conn isFiltered bool } @@ -106,7 +108,7 @@ func NewAdapterWithPool(pool *redis.Pool) (*Adapter, error) { type Option func(*Adapter) -func NewAdpaterWithOption(options ...Option) (*Adapter, error) { +func NewAdapterWithOption(options ...Option) (*Adapter, error) { a := &Adapter{} for _, option := range options { option(a) @@ -149,24 +151,31 @@ func WithKey(key string) Option { } } +func WithTls(tlsConfig *tls.Config) Option { + return func(a *Adapter) { + a.tlsConfig = tlsConfig + } +} + func (a *Adapter) open() error { //redis.Dial("tcp", "127.0.0.1:6379") + useTls := a.tlsConfig != nil if a.username != "" { - conn, err := redis.Dial(a.network, a.address, redis.DialUsername(a.username), redis.DialPassword(a.password)) + conn, err := redis.Dial(a.network, a.address, redis.DialUsername(a.username), redis.DialPassword(a.password), redis.DialTLSConfig(a.tlsConfig), redis.DialUseTLS(useTls)) if err != nil { return err } a.conn = conn } else if a.password == "" { - conn, err := redis.Dial(a.network, a.address) + conn, err := redis.Dial(a.network, a.address, redis.DialTLSConfig(a.tlsConfig), redis.DialUseTLS(useTls)) if err != nil { return err } a.conn = conn } else { - conn, err := redis.Dial(a.network, a.address, redis.DialPassword(a.password)) + conn, err := redis.Dial(a.network, a.address, redis.DialPassword(a.password), redis.DialTLSConfig(a.tlsConfig), redis.DialUseTLS(useTls)) if err != nil { return err } diff --git a/adapter_test.go b/adapter_test.go index 8ac3f99..0082d11 100644 --- a/adapter_test.go +++ b/adapter_test.go @@ -361,6 +361,20 @@ func TestAdapters(t *testing.T) { // Use the following if you use Redis with a account // a, err := NewAdapterWithUser("tcp", "127.0.0.1:6379", "testaccount", "userpass") + testSaveLoad(t, a) + testAutoSave(t, a) + testFilteredPolicy(t, a) + testAddPolicies(t, a) + testRemovePolicies(t, a) + testUpdatePolicies(t, a) + testUpdateFilteredPolicies(t, a) +} + +func TestAdapterWithOption(t *testing.T) { + a, _ := NewAdapterWithOption(WithNetwork("tcp"), WithAddress("127.0.0.1:6379")) + // User the following if use TLS to connect to redis + // var clientTLSConfig tls.Config + // a, err := NewAdapterWithOption(WithTls(&clientTLSConfig)) testSaveLoad(t, a) testAutoSave(t, a)