diff --git a/rpc/plugins/push/prosumer.go b/rpc/plugins/push/prosumer.go index 7a9d312..8b28c46 100644 --- a/rpc/plugins/push/prosumer.go +++ b/rpc/plugins/push/prosumer.go @@ -56,6 +56,24 @@ func NewProsumer(client *core.Client, id ...string) *Prosumer { return p } +func (p *Prosumer) onError(err error) { + if p.OnError != nil { + p.OnError(err) + } +} + +func (p *Prosumer) onSubscribe(topic string) { + if p.OnSubscribe != nil { + p.OnSubscribe(topic) + } +} + +func (p *Prosumer) onUnsubscribe(topic string) { + if p.OnUnsubscribe != nil { + p.OnUnsubscribe(topic) + } +} + func (p *Prosumer) Client() *core.Client { return p.client } @@ -99,14 +117,19 @@ func (p *Prosumer) call(callback Callback, message Message) { default: v := reflect.ValueOf(callback) t := v.Type() - switch t.NumIn() { - case 1: - if data, err := io.Convert(message.Data, t.In(0)); err != nil { - v.Call([]reflect.Value{reflect.ValueOf(data)}) + if n := t.NumIn(); n >= 1 { + data, err := io.Convert(message.Data, t.In(0)) + if err != nil { + p.onError(err) + return } - case 2: - if data, err := io.Convert(message.Data, t.In(0)); err != nil { + switch n { + case 1: + v.Call([]reflect.Value{reflect.ValueOf(data)}) + case 2: v.Call([]reflect.Value{reflect.ValueOf(data), reflect.ValueOf(message.From)}) + default: + panic("invalid callback: " + t.String()) } } } @@ -120,9 +143,7 @@ func (p *Prosumer) message() { if p.RetryInterval != 0 { <-time.After(p.RetryInterval) } - if p.OnError != nil { - p.OnError(err) - } + p.onError(err) } continue } @@ -138,9 +159,7 @@ func (p *Prosumer) Subscribe(topic string, callback Callback) (result bool, err p.callbacks.Store(topic, callback) result, err = p.proxy.subscribe(topic) go p.message() - if p.OnSubscribe != nil { - p.OnSubscribe(topic) - } + p.onSubscribe(topic) } return } @@ -149,9 +168,7 @@ func (p *Prosumer) Unsubscribe(topic string) (result bool, err error) { if p.ID() != "" { result, err = p.proxy.unsubscribe(topic) p.callbacks.Delete(topic) - if p.OnUnsubscribe != nil { - p.OnUnsubscribe(topic) - } + p.onUnsubscribe(topic) } return } diff --git a/rpc/rpc_test.go b/rpc/rpc_test.go index 04b0e21..dfa8e32 100644 --- a/rpc/rpc_test.go +++ b/rpc/rpc_test.go @@ -779,8 +779,11 @@ func TestPush(t *testing.T) { time.Sleep(time.Millisecond * 5) client1 := rpc.NewClient("tcp://127.0.0.1/") - client1.Use(log.Plugin.IOHandler) + //client1.Use(log.Plugin.IOHandler) prosumer1 := push.NewProsumer(client1, "1") + prosumer1.OnError = func(e error) { + fmt.Println(e.Error()) + } prosumer1.OnSubscribe = func(topic string) { fmt.Println(topic, "is subscribed.") } @@ -788,10 +791,10 @@ func TestPush(t *testing.T) { fmt.Println(topic, "is unsubscribed.") } client2 := rpc.NewClient("tcp://127.0.0.1/") - client2.Use(log.Plugin.IOHandler) + //client2.Use(log.Plugin.IOHandler) prosumer2 := push.NewProsumer(client2, "2") - prosumer1.Subscribe("test", func(data string) { - fmt.Println(data) + prosumer1.Subscribe("test", func(data int, from string) { + fmt.Printf("%v from %v\n", data, from) }) prosumer1.Subscribe("test2", func(message push.Message) { fmt.Println(message)