Skip to content

Commit

Permalink
feat: support std error
Browse files Browse the repository at this point in the history
fix: fix concurrent map writes

BREAKING CHANGE: change the error with std error which might break the origin err handler.
  • Loading branch information
asjdf committed Aug 21, 2023
1 parent fb9a084 commit 9f9b1b3
Show file tree
Hide file tree
Showing 4 changed files with 84 additions and 20 deletions.
8 changes: 8 additions & 0 deletions error.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
package inject

import "errors"

var (
ErrValueNotFound = errors.New("value not found")
ErrValueCanNotSet = errors.New("value can not set")
)
15 changes: 15 additions & 0 deletions error_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
package inject

import (
"errors"
"fmt"
"reflect"
"testing"
)

func TestError(t *testing.T) {
err := fmt.Errorf("%w: %v", ErrValueNotFound, reflect.TypeOf(""))
expect(t, errors.Is(err, ErrValueNotFound), true)
err = fmt.Errorf("%w: %v", ErrValueCanNotSet, reflect.TypeOf(""))
expect(t, errors.Is(err, ErrValueCanNotSet), true)
}
25 changes: 17 additions & 8 deletions inject.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ package inject
import (
"fmt"
"reflect"
"sync"
)

// Injector represents an interface for mapping and injecting dependencies into
Expand Down Expand Up @@ -79,10 +80,11 @@ var _ Injector = (*injector)(nil)
type injector struct {
values map[reflect.Type]reflect.Value
parent Injector
mu sync.RWMutex
}

// InterfaceOf dereferences a pointer to an Interface type. It panics if value
// is not an pointer to an interface.
// is not a pointer to an interface.
func InterfaceOf(value interface{}) reflect.Type {
t := reflect.TypeOf(value)
for t.Kind() == reflect.Ptr {
Expand Down Expand Up @@ -127,7 +129,7 @@ func (inj *injector) fastInvoke(f FastInvoker, t reflect.Type, numIn int) ([]ref
argType = t.In(i)
val = inj.Value(argType)
if !val.IsValid() {
return nil, fmt.Errorf("value not found for type %v", argType)
return nil, fmt.Errorf("%w: %v", ErrValueNotFound, argType)
}

in[i] = val.Interface()
Expand All @@ -146,7 +148,7 @@ func (inj *injector) callInvoke(f interface{}, t reflect.Type, numIn int) ([]ref
argType = t.In(i)
val = inj.Value(argType)
if !val.IsValid() {
return nil, fmt.Errorf("value not found for type %v", argType)
return nil, fmt.Errorf("%w: %v", ErrValueNotFound, argType)
}

in[i] = val
Expand Down Expand Up @@ -176,7 +178,7 @@ func (inj *injector) Apply(val interface{}) error {
ft := f.Type()
v := inj.Value(ft)
if !v.IsValid() {
return fmt.Errorf("value not found for type %v", ft)
return fmt.Errorf("%w: %v", ErrValueNotFound, ft)
}

f.Set(v)
Expand All @@ -187,24 +189,32 @@ func (inj *injector) Apply(val interface{}) error {
}

func (inj *injector) Map(values ...interface{}) TypeMapper {
inj.mu.RLock()
for _, val := range values {
inj.values[reflect.TypeOf(val)] = reflect.ValueOf(val)
}
inj.mu.RUnlock()
return inj
}

func (inj *injector) MapTo(val, ifacePtr interface{}) TypeMapper {
inj.mu.RLock()
inj.values[InterfaceOf(ifacePtr)] = reflect.ValueOf(val)
inj.mu.RUnlock()
return inj
}

func (inj *injector) Set(typ reflect.Type, val reflect.Value) TypeMapper {
inj.mu.Lock()
inj.values[typ] = val
inj.mu.Unlock()
return inj
}

func (inj *injector) Value(t reflect.Type) reflect.Value {
inj.mu.RLock()
val := inj.values[t]
inj.mu.RUnlock()

if val.IsValid() {
return val
Expand Down Expand Up @@ -233,16 +243,15 @@ func (inj *injector) Load(val interface{}) error {
valType := reflect.TypeOf(val)
value := inj.Value(valType)
if !value.IsValid() {

return fmt.Errorf("value not found for type %v", valType)
return fmt.Errorf("%w: %v", ErrValueNotFound, valType)
}
v := reflect.ValueOf(val)
if v.Kind() != reflect.Ptr {
return fmt.Errorf("value not a pointer for type %v", valType)
return fmt.Errorf("%w: %v", ErrValueCanNotSet, valType)
}
v = v.Elem()
if !v.CanSet() {
return fmt.Errorf("value not settable for type %v", valType)
return fmt.Errorf("%w: %v", ErrValueCanNotSet, valType)
}
v.Set(value.Elem())
return nil
Expand Down
56 changes: 44 additions & 12 deletions inject_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,9 @@ package inject
import (
"fmt"
"reflect"
"sync"
"testing"
"unsafe"
)

type specialString interface{}
Expand All @@ -23,9 +25,9 @@ func (g *greeter) String() string {
}

/* Test Helpers */
func expect(t *testing.T, a interface{}, b interface{}) {
if a != b {
t.Errorf("Expected %v (type %v) - Got %v (type %v)", b, reflect.TypeOf(b), a, reflect.TypeOf(a))
func expect(t testing.TB, actual interface{}, expect interface{}) {
if actual != expect {
t.Errorf("Expected %v (type %v) - Got %v (type %v)", expect, reflect.TypeOf(expect), actual, reflect.TypeOf(actual))
}
}

Expand All @@ -41,6 +43,11 @@ func (myFastInvoker) Invoke([]interface{}) ([]reflect.Value, error) {
return nil, nil
}

func TestInjectorSize(t *testing.T) {
// prevent unnecessary memory usage increases
expect(t, unsafe.Alignof(injector{}), uintptr(8))
}

func BenchmarkNew(b *testing.B) {
b.ReportAllocs()
var j Injector
Expand Down Expand Up @@ -140,6 +147,24 @@ func TestInjector_InterfaceOf(t *testing.T) {
InterfaceOf((*testing.T)(nil))
}

func TestInjector_Map(t *testing.T) {
inj := New()

g := &greeter{"Jeremy"}
inj.Map(g)

expect(t, inj.Value(InterfaceOf((*fmt.Stringer)(nil))).IsValid(), true)
}

func BenchmarkInjector_Map(b *testing.B) {
b.ReportAllocs()
inj := New()
b.ResetTimer()
for i := 0; i < b.N; i++ {
inj.Map("Jeremy")
}
}

func TestInjector_Set(t *testing.T) {
inj := New()

Expand Down Expand Up @@ -204,15 +229,6 @@ func TestInjector_SetParent(t *testing.T) {
expect(t, inj2.Value(InterfaceOf((*specialString)(nil))).IsValid(), true)
}

func TestInjector_Implementors(t *testing.T) {
inj := New()

g := &greeter{"Jeremy"}
inj.Map(g)

expect(t, inj.Value(InterfaceOf((*fmt.Stringer)(nil))).IsValid(), true)
}

func TestIsFastInvoker(t *testing.T) {
expect(t, IsFastInvoker(myFastInvoker(nil)), true)
}
Expand Down Expand Up @@ -249,3 +265,19 @@ func BenchmarkInjector_FastInvoke(b *testing.B) {
_, _ = inj.Invoke(fn)
}
}

func TestConcurrentMap(t *testing.T) {
inj := New()
var trigger, wg sync.WaitGroup
trigger.Add(1)
for i := 0; i < 1000; i++ {
wg.Add(1)
go func() {
trigger.Wait()
inj.Map("")
wg.Done()
}()
}
trigger.Done()
wg.Done()
}

0 comments on commit 9f9b1b3

Please sign in to comment.