diff --git a/rpc/plugins/reverse/caller.go b/rpc/plugins/reverse/caller.go new file mode 100644 index 0000000..2103444 --- /dev/null +++ b/rpc/plugins/reverse/caller.go @@ -0,0 +1,338 @@ +/*--------------------------------------------------------*\ +| | +| hprose | +| | +| Official WebSite: https://hprose.com | +| | +| rpc/plugins/reverse/caller.go | +| | +| LastModified: May 19, 2021 | +| Author: Ma Bingyao | +| | +\*________________________________________________________*/ + +package reverse + +import ( + "context" + "errors" + "reflect" + "sync" + "sync/atomic" + "time" + + "github.com/hprose/hprose-golang/v3/io" + "github.com/hprose/hprose-golang/v3/rpc/core" + "github.com/modern-go/reflect2" + cmap "github.com/orcaman/concurrent-map" +) + +type call [3]interface{} + +func newCall(index int, name string, args []interface{}) (c call) { + c[0] = index + c[1] = name + c[2] = args + return +} + +func (c call) Value() (index int, name string, args []interface{}) { + return c[0].(int), c[1].(string), c[2].([]interface{}) +} + +type callCache struct { + c []call + sync.Mutex +} + +func (cc *callCache) Append(c call) { + cc.Lock() + defer cc.Unlock() + cc.c = append(cc.c, c) +} + +func (cc *callCache) Delete(index int) { + cc.Lock() + defer cc.Unlock() + for i := 0; i < len(cc.c); i++ { + if cc.c[i][0].(int) == index { + cc.c = append(cc.c[:i], cc.c[i+1:]...) + return + } + } +} + +func (cc *callCache) Take() (calls []call) { + cc.Lock() + defer cc.Unlock() + calls = cc.c + cc.c = nil + return +} + +type returnValue [3]interface{} + +func newReturnValue(index int, result interface{}, err string) (r returnValue) { + r[0] = index + r[1] = result + r[2] = err + return +} + +func (r returnValue) Index() int { + return r[0].(int) +} + +func (r returnValue) Value(returnType []reflect.Type) ([]interface{}, error) { + err := r[2].(string) + if err != "" { + return nil, errors.New(err) + } + n := len(returnType) + switch n { + case 0: + return nil, nil + case 1: + if result, err := io.Convert(r[1], returnType[0]); err != nil { + return nil, err + } else { + return []interface{}{result}, nil + } + default: + results := make([]interface{}, n) + values := r[1].([]interface{}) + count := len(values) + for i := 0; i < n && i < count; i++ { + if result, err := io.Convert(values[i], returnType[i]); err != nil { + return nil, err + } else { + results[i] = result + } + } + for i := count; i < n; i++ { + t := reflect2.Type2(returnType[i]) + results[i] = t.Indirect(t.New()) + } + return results, nil + } +} + +type resultMap struct { + results map[int]chan returnValue + sync.Mutex +} + +func newResultMap() *resultMap { + return &resultMap{ + results: make(map[int]chan returnValue), + } +} + +func (m *resultMap) GetAndDelete(index int) chan returnValue { + m.Lock() + defer m.Unlock() + if result, ok := m.results[index]; ok { + delete(m.results, index) + return result + } + return nil +} + +func (m *resultMap) Delete(index int) { + m.Lock() + defer m.Unlock() + delete(m.results, index) +} + +func (m *resultMap) Set(index int, result chan returnValue) { + m.Lock() + defer m.Unlock() + m.results[index] = result +} + +var ( + emptyArgs = make([]interface{}, 0) + emptyCall = make([]call, 0) +) + +type Caller struct { + *core.Service + HeartBeat time.Duration + Timeout time.Duration + calls cmap.ConcurrentMap + results cmap.ConcurrentMap + responders cmap.ConcurrentMap + onlines cmap.ConcurrentMap + counter int32 +} + +func NewCaller(service *core.Service) *Caller { + caller := &Caller{ + Service: service, + HeartBeat: time.Minute * 2, + Timeout: time.Second * 30, + calls: cmap.New(), + results: cmap.New(), + responders: cmap.New(), + onlines: cmap.New(), + } + service.Use(caller.handler). + AddFunction(caller.close, "!!"). + AddFunction(caller.begin, "!"). + AddFunction(caller.end, "=") + return caller +} + +func (c *Caller) ID(ctx context.Context) (id string) { + if id = core.GetServiceContext(ctx).RequestHeaders().GetString("id"); id == "" { + panic("client unique id not found") + } + return +} + +func (c *Caller) send(id string, responder chan []call) bool { + if calls, ok := c.calls.Get(id); ok { + calls := calls.(*callCache).Take() + if len(calls) == 0 { + return false + } + responder <- calls + return true + } + return false +} + +func (c *Caller) response(id string) { + if responder, ok := c.responders.Pop(id); ok { + responder := responder.(chan []call) + if !c.send(id, responder) { + if c.responders.SetIfAbsent(id, responder) { + responder <- nil + } + } + } +} + +func (c *Caller) stop(ctx context.Context) string { + id := c.ID(ctx) + if responder, ok := c.responders.Pop(id); ok { + responder.(chan []call) <- nil + } + return id +} + +func (c *Caller) close(ctx context.Context) { + id := c.stop(ctx) + c.onlines.Remove(id) +} + +func (c *Caller) begin(ctx context.Context) []call { + id := c.stop(ctx) + c.onlines.Set(id, true) + responder := make(chan []call, 1) + if !c.send(id, responder) { + c.responders.Upsert(id, responder, func(exist bool, valueInMap interface{}, newValue interface{}) interface{} { + if exist { + valueInMap.(chan []call) <- nil + } + return newValue + }) + if c.HeartBeat > 0 { + ctx, cancel := context.WithTimeout(ctx, c.HeartBeat) + defer cancel() + select { + case <-ctx.Done(): + responder <- emptyCall + case result := <-responder: + return result + } + } + } + return <-responder +} + +func (c *Caller) end(ctx context.Context, results []returnValue) { + id := c.ID(ctx) + for _, rv := range results { + if r, ok := c.results.Get(id); ok { + if value := r.(*resultMap).GetAndDelete(rv.Index()); value != nil { + value <- rv + } + } + } +} + +func (c *Caller) Invoke(id string, name string, args []interface{}, returnType ...reflect.Type) ([]interface{}, error) { + return c.InvokeContext(context.Background(), id, name, args, returnType...) +} + +func (c *Caller) InvokeContext(ctx context.Context, id string, name string, args []interface{}, returnType ...reflect.Type) ([]interface{}, error) { + if args == nil { + args = emptyArgs + } + index := int(atomic.AddInt32(&c.counter, 1) & 0x7fffffff) + var calls *callCache + if cc, ok := c.calls.Get(id); ok { + calls = cc.(*callCache) + } else { + calls = new(callCache) + if !c.calls.SetIfAbsent(id, calls) { + cc, _ := c.calls.Get(id) + calls = cc.(*callCache) + } + } + calls.Append(newCall(index, name, args)) + var results *resultMap + if rm, ok := c.results.Get(id); ok { + results = rm.(*resultMap) + } else { + results = newResultMap() + if !c.results.SetIfAbsent(id, results) { + rm, _ := c.results.Get(id) + results = rm.(*resultMap) + } + } + result := make(chan returnValue, 1) + results.Set(index, result) + c.response(id) + if c.Timeout > 0 { + ctx, cancel := context.WithTimeout(ctx, c.HeartBeat) + defer cancel() + select { + case <-ctx.Done(): + calls.Delete(index) + results.Delete(index) + return nil, core.ErrTimeout + case result := <-result: + return result.Value(returnType) + } + } + return (<-result).Value(returnType) +} + +func (c *Caller) UseService(remoteService interface{}, id string, namespace ...string) { + ns := "" + if len(namespace) > 0 { + ns = namespace[0] + } + core.Proxy.Build(remoteService, invocation{caller: c, id: id, namespace: ns}.Invoke) +} + +func (c *Caller) Exists(id string) bool { + return c.onlines.Has(id) +} + +func (c *Caller) IdList() []string { + return c.onlines.Keys() +} + +func (c *Caller) handler(ctx context.Context, name string, args []interface{}, next core.NextInvokeHandler) (result []interface{}, err error) { + core.GetServiceContext(ctx).Items().Set("caller", c) + return next(ctx, name, args) +} + +func UseService(ctx context.Context, remoteService interface{}, namespace ...string) *Caller { + caller := core.GetServiceContext(ctx).Items().GetInterface("caller").(*Caller) + caller.UseService(remoteService, caller.ID(ctx), namespace...) + return caller +} diff --git a/rpc/plugins/reverse/invocation.go b/rpc/plugins/reverse/invocation.go new file mode 100644 index 0000000..11594f9 --- /dev/null +++ b/rpc/plugins/reverse/invocation.go @@ -0,0 +1,73 @@ +/*--------------------------------------------------------*\ +| | +| hprose | +| | +| Official WebSite: https://hprose.com | +| | +| rpc/plugins/reverse/invocation.go | +| | +| LastModified: May 19, 2021 | +| Author: Ma Bingyao | +| | +\*________________________________________________________*/ + +package reverse + +import ( + "context" + "reflect" + "strings" + + "github.com/hprose/hprose-golang/v3/rpc/core" +) + +var errorType = reflect.TypeOf((*error)(nil)).Elem() + +type invocation struct { + caller *Caller + id string + namespace string +} + +func (i invocation) Invoke(proxy interface{}, method reflect.StructField, name string, args []interface{}) (results []interface{}, err error) { + var rpcContext core.Context + var ctx context.Context + if len(args) > 0 { + switch c := args[0].(type) { + case core.Context: + rpcContext = c + args = args[1:] + case context.Context: + ctx = c + rpcContext, _ = core.FromContext(ctx) + args = args[1:] + } + } + if ctx == nil { + ctx = context.Background() + } + if rpcContext == nil { + rpcContext = core.NewContext() + } + if _, ok := core.FromContext(ctx); !ok { + ctx = core.WithContext(ctx, rpcContext) + } + tagParser := core.ParseTag(nil, method.Tag) + if tagParser.Name != "" { + name = tagParser.Name + } + name = strings.Replace(name, ".", "_", -1) //nolint:gocritic + if i.namespace != "" { + name = i.namespace + "_" + name + } + t := method.Type + n := t.NumOut() + returnType := make([]reflect.Type, n) + for i := 0; i < n; i++ { + returnType[i] = t.Out(i) + } + if n > 0 && returnType[n-1] == errorType { + returnType = returnType[:n-1] + } + return i.caller.InvokeContext(ctx, i.id, name, args, returnType...) +} diff --git a/rpc/plugins/reverse/provider.go b/rpc/plugins/reverse/provider.go new file mode 100644 index 0000000..482c6da --- /dev/null +++ b/rpc/plugins/reverse/provider.go @@ -0,0 +1,348 @@ +/*--------------------------------------------------------*\ +| | +| hprose | +| | +| Official WebSite: https://hprose.com | +| | +| rpc/plugins/reverse/provider.go | +| | +| LastModified: May 19, 2021 | +| Author: Ma Bingyao | +| | +\*________________________________________________________*/ + +package reverse + +import ( + "context" + "reflect" + "sync" + "sync/atomic" + "time" + + "github.com/hprose/hprose-golang/v3/io" + "github.com/hprose/hprose-golang/v3/rpc/core" +) + +type contextMissingMethod = func(ctx context.Context, name string, args []interface{}) (result []interface{}, err error) +type missingMethod = func(name string, args []interface{}) (result []interface{}, err error) + +type ProviderContext struct { + core.Context + client *core.Client + method core.Method +} + +func NewProviderContext(client *core.Client, method core.Method) *ProviderContext { + return &ProviderContext{ + client: client, + method: method, + } +} + +func (c *ProviderContext) Client() *core.Client { + return c.client +} + +func (c *ProviderContext) Method() core.Method { + return c.method +} + +func (c *ProviderContext) Clone() core.Context { + return &ProviderContext{ + c.Context.Clone(), + c.client, + c.method, + } +} + +// GetProviderContext returns the *reverse.ProviderContext bound to the context. +func GetProviderContext(ctx context.Context) *ProviderContext { + if c, ok := core.FromContext(ctx); ok { + return c.(*ProviderContext) + } + return nil +} + +type Provider struct { + client *core.Client + proxy provider + invokeManager core.PluginManager + methodManager core.MethodManager + closed int32 + RetryInterval time.Duration + OnError func(error) + Debug bool +} + +type provider struct { + close func() error `name:"!!"` + begin func() ([]call, error) `name:"!"` + end func(results []returnValue) error `name:"="` +} + +func NewProvider(client *core.Client, id ...string) *Provider { + p := &Provider{ + client: client, + RetryInterval: time.Second, + closed: 1, + } + if len(id) > 0 && id[0] != "" { + p.SetID(id[0]) + } + p.client.UseService(&p.proxy) + p.invokeManager = core.NewInvokeManager(p.Execute) + p.methodManager = core.NewMethodManager() + p.AddFunction(p.methodManager.Names, "~") + return p +} + +func (p *Provider) onError(err error) { + if p.OnError != nil { + p.OnError(err) + } +} + +func (p *Provider) Client() *core.Client { + return p.client +} + +func (p *Provider) ID() (id string) { + if id = p.client.RequestHeaders().GetString("id"); id == "" { + panic("client unique id not found") + } + return +} + +func (p *Provider) SetID(id string) { + p.client.RequestHeaders().Set("id", id) +} + +func (p *Provider) Execute(ctx context.Context, name string, args []interface{}) (result []interface{}, err error) { + method := GetProviderContext(ctx).method + if method.Missing() { + if method.PassContext() { + return method.(interface{}).(contextMissingMethod)(ctx, name, args) + } + return method.(interface{}).(missingMethod)(name, args) + } + n := len(args) + var in []reflect.Value + if method.PassContext() { + in = make([]reflect.Value, n+1) + in[0] = reflect.ValueOf(ctx) + for i := 0; i < n; i++ { + in[i+1] = reflect.ValueOf(args[i]) + } + } else { + in = make([]reflect.Value, n) + for i := 0; i < n; i++ { + in[i] = reflect.ValueOf(args[i]) + } + } + f := method.Func() + out := f.Call(in) + n = len(out) + if method.ReturnError() { + if !out[n-1].IsNil() { + err = out[n-1].Interface().(error) + } + out = out[:n-1] + n-- + } + for i := 0; i < n; i++ { + result = append(result, out[i].Interface()) + } + return +} + +func (p *Provider) process(c call) (rv returnValue) { + index, name, args := c.Value() + defer func() { + if e := recover(); e != nil { + err := core.NewPanicError(e) + if p.Debug { + rv = newReturnValue(index, nil, err.String()) + } else { + rv = newReturnValue(index, nil, err.Error()) + } + } + }() + method := p.Get(name) + if method == nil { + return newReturnValue(index, nil, "Can't find this method "+name+"().") + } + if !method.Missing() { + count := len(args) + parameters := method.Parameters() + paramTypes := make([]reflect.Type, count) + if method.Func().Type().IsVariadic() { + n := len(parameters) + copy(paramTypes, parameters[:n-1]) + for i := n - 1; i < count; i++ { + paramTypes[i] = parameters[n-1].Elem() + } + } else { + copy(paramTypes, parameters) + } + for i, t := range paramTypes { + if arg, err := io.Convert(args[i], t); err != nil { + return newReturnValue(index, nil, err.Error()) + } else { + args[i] = arg + } + } + } + ctx := core.WithContext(context.Background(), NewProviderContext(p.client, method)) + results, err := p.invokeManager.Handler().(core.NextInvokeHandler)(ctx, name, args) + var result interface{} + switch len(results) { + case 0: + result = nil + case 1: + result = results[0] + default: + result = results + } + if err != nil { + return newReturnValue(index, result, err.Error()) + } + return newReturnValue(index, result, "") +} + +func (p *Provider) dispatch(calls []call) { + n := len(calls) + results := make([]returnValue, n) + var wg sync.WaitGroup + wg.Add(n) + for i := 0; i < n; i++ { + go func(i int) { + results[i] = p.process(calls[i]) + wg.Done() + }(i) + } + wg.Wait() + for { + if err := p.proxy.end(results); err != nil { + if !core.IsTimeoutError(err) { + if p.RetryInterval != 0 { + <-time.After(p.RetryInterval) + } + p.onError(err) + } + continue + } + return + } +} + +func (p *Provider) Listen() { + if !atomic.CompareAndSwapInt32(&p.closed, 1, 0) { + return + } + for atomic.LoadInt32(&p.closed) == 0 { + calls, err := p.proxy.begin() + if err != nil { + if !core.IsTimeoutError(err) { + if p.RetryInterval != 0 { + <-time.After(p.RetryInterval) + } + p.onError(err) + } + continue + } + if calls == nil { + return + } + go p.dispatch(calls) + } + atomic.StoreInt32(&p.closed, 1) +} + +func (p *Provider) Close() error { + if atomic.CompareAndSwapInt32(&p.closed, 0, 1) { + return p.proxy.close() + } + return core.ErrClosed +} + +// Use plugin handlers. +func (p *Provider) Use(handler ...core.PluginHandler) *Provider { + invokeHandlers, _ := core.SeparatePluginHandlers(handler) + if len(invokeHandlers) > 0 { + p.invokeManager.Use(invokeHandlers...) + } + return p +} + +// Unuse plugin handlers. +func (p *Provider) Unuse(handler ...core.PluginHandler) *Provider { + invokeHandlers, _ := core.SeparatePluginHandlers(handler) + if len(invokeHandlers) > 0 { + p.invokeManager.Unuse(invokeHandlers...) + } + return p +} + +// Get returns the published method by name. +func (p *Provider) Get(name string) core.Method { + return p.methodManager.Get(name) +} + +// Remove is used for unpublishing method by the specified name. +func (p *Provider) Remove(name string) *Provider { + p.methodManager.Remove(name) + return p +} + +// Add is used for publishing the method. +func (p *Provider) Add(method core.Method) *Provider { + p.methodManager.Add(method) + return p +} + +// AddFunction is used for publishing function f with alias. +func (p *Provider) AddFunction(f interface{}, alias ...string) *Provider { + p.methodManager.AddFunction(f, alias...) + return p +} + +// AddMethod is used for publishing method named name on target with alias. +func (p *Provider) AddMethod(name string, target interface{}, alias ...string) *Provider { + p.methodManager.AddMethod(name, target, alias...) + return p +} + +// AddMethods is used for publishing methods named names on target with namespace. +func (p *Provider) AddMethods(names []string, target interface{}, namespace ...string) *Provider { + p.methodManager.AddMethods(names, target, namespace...) + return p +} + +// AddInstanceMethods is used for publishing all the public methods and func fields with namespace. +func (p *Provider) AddInstanceMethods(target interface{}, namespace ...string) *Provider { + p.methodManager.AddInstanceMethods(target, namespace...) + return p +} + +// AddAllMethods will publish all methods and non-nil function fields on the +// obj self and on its anonymous or non-anonymous struct fields (or pointer to +// pointer ... to pointer struct fields). This is a recursive operation. +// So it's a pit, if you do not know what you are doing, do not step on. +func (p *Provider) AddAllMethods(target interface{}, namespace ...string) *Provider { + p.methodManager.AddAllMethods(target, namespace...) + return p +} + +// AddMissingMethod is used for publishing a method, +// all methods not explicitly published will be redirected to this method. +func (p *Provider) AddMissingMethod(f interface{}) *Provider { + p.methodManager.AddMissingMethod(f) + return p +} + +// AddNetRPCMethods is used for publishing methods defined for net/rpc. +func (p *Provider) AddNetRPCMethods(rcvr interface{}, namespace ...string) *Provider { + p.methodManager.AddNetRPCMethods(rcvr, namespace...) + return p +} diff --git a/rpc/rpc_test.go b/rpc/rpc_test.go index ba296b8..0a141fb 100644 --- a/rpc/rpc_test.go +++ b/rpc/rpc_test.go @@ -6,7 +6,7 @@ | | | rpc/rpc_test.go | | | -| LastModified: May 12, 2021 | +| LastModified: May 19, 2021 | | Author: Ma Bingyao | | | \*________________________________________________________*/ @@ -32,6 +32,7 @@ import ( "github.com/hprose/hprose-golang/v3/rpc" "github.com/hprose/hprose-golang/v3/rpc/plugins/log" "github.com/hprose/hprose-golang/v3/rpc/plugins/push" + "github.com/hprose/hprose-golang/v3/rpc/plugins/reverse" "github.com/stretchr/testify/assert" ) @@ -841,3 +842,35 @@ func TestPush(t *testing.T) { assert.NoError(t, err) server.Close() } + +func TestReverseInvoke(t *testing.T) { + service := rpc.NewService() + caller := reverse.NewCaller(service) + server, err := net.Listen("tcp", "127.0.0.1:8412") + assert.NoError(t, err) + err = service.Bind(server) + assert.NoError(t, err) + + time.Sleep(time.Millisecond * 5) + + client := rpc.NewClient("tcp://127.0.0.1/") + client.Use(log.Plugin) + provider := reverse.NewProvider(client, "1") + provider.Debug = true + provider.AddFunction(func(name string) string { + return "hello " + name + }, "hello") + go provider.Listen() + + time.Sleep(time.Millisecond * 100) + + var proxy struct { + Hello func(name string) (string, error) + } + caller.UseService(&proxy, "1") + result, err := proxy.Hello("world") + assert.Equal(t, "hello world", result) + assert.NoError(t, err) + provider.Close() + server.Close() +}