From c3326bd7cf986c67f1a96d3b5668523b98727f31 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E6=BA=90=E6=96=87=E9=9B=A8?= <41315874+fumiama@users.noreply.github.com> Date: Wed, 8 May 2024 00:37:27 +0900 Subject: [PATCH] optimize(client): impl. syncx on resultStore --- client/base.go | 30 +++++++++++++++------------- client/client.go | 2 +- client/message.go | 4 +--- client/resultStore.go | 43 +++++++++++++++++++---------------------- go.mod | 2 +- packets/wtlogin/oicq.go | 2 +- utils/sign.go | 6 +++--- 7 files changed, 44 insertions(+), 45 deletions(-) diff --git a/client/base.go b/client/base.go index 68fcca5b..01a71481 100644 --- a/client/base.go +++ b/client/base.go @@ -43,7 +43,7 @@ type QQClient struct { appInfo *info.AppInfo deviceInfo *info.DeviceInfo sig *info.SigInfo - signProvider func(string, int, []byte) map[string]string + signProvider func(string, uint32, []byte) map[string]string pushStore chan *wtlogin.SSOPacket @@ -88,7 +88,7 @@ func (c *QQClient) SendOidbPacketAndWait(pkt *oidb.OidbPacket) (*wtlogin.SSOPack } func (c *QQClient) SendUniPacket(cmd string, buf []byte) error { - seq := c.getSeq() + seq := c.getAndIncreaseSequence() var sign map[string]string if c.signProvider != nil { sign = c.signProvider(cmd, seq, buf) @@ -98,27 +98,27 @@ func (c *QQClient) SendUniPacket(cmd string, buf []byte) error { } func (c *QQClient) SendUniPacketAndAwait(cmd string, buf []byte) (*wtlogin.SSOPacket, error) { - seq := c.getSeq() + seq := c.getAndIncreaseSequence() var sign map[string]string if c.signProvider != nil { sign = c.signProvider(cmd, seq, buf) } packet := wtlogin.BuildUniPacket(int(c.Uin), seq, cmd, sign, c.appInfo, c.deviceInfo, c.sig, buf) - return c.SendAndWait(packet, seq, 5) + return c.SendAndWait(packet, int(seq), 5) } func (c *QQClient) Send(data []byte) error { return c.tcp.Write(data) } -func (c *QQClient) SendAndWait(data []byte, seq, timeout int) (*wtlogin.SSOPacket, error) { - resultStore.AddSeq(seq) +func (c *QQClient) SendAndWait(data []byte, seq int, timeout int) (*wtlogin.SSOPacket, error) { + fetcher.AddSeq(seq) err := c.tcp.Write(data) if err != nil { // 出错了要删掉 - resultStore.DeleteSeq(seq) + fetcher.DeleteSeq(seq) } - return resultStore.Fecth(seq, timeout) + return fetcher.Fecth(seq, timeout) } func (c *QQClient) SSOHeartbeat(calcLatency bool) int64 { @@ -178,17 +178,17 @@ func (c *QQClient) OnMessage(msgLen int) { if packet.Seq > 0 { // uni rsp networkLogger.Debugf("%d(%d) -> %s, extra: %s", packet.Seq, packet.RetCode, packet.Cmd, packet.Extra) - if packet.RetCode != 0 && resultStore.ContainSeq(packet.Seq) { + if packet.RetCode != 0 && fetcher.ContainSeq(packet.Seq) { networkLogger.Errorf("error ssopacket retcode: %d, extra: %s", packet.RetCode, packet.Extra) return } else if packet.RetCode != 0 { networkLogger.Errorf("Unexpected error on sso layer: %d: %s", packet.RetCode, packet.Extra) return } - if !resultStore.ContainSeq(packet.Seq) { + if !fetcher.ContainSeq(packet.Seq) { networkLogger.Warningf("Unknown packet: %s(%d), ignore", packet.Cmd, packet.Seq) } else { - resultStore.AddResult(packet.Seq, packet) + fetcher.AddResult(packet.Seq, packet) } } else { // server pushed if _, ok := listeners[packet.Cmd]; ok { @@ -265,6 +265,10 @@ func (c *QQClient) OnDisconnected() { c.Online.Store(false) } -func (c *QQClient) getSeq() int { - return int(atomic.AddUint32(&c.sig.Sequence, 1) % 0x8000) +func (c *QQClient) getAndIncreaseSequence() uint32 { + return atomic.AddUint32(&c.sig.Sequence, 1) % 0x8000 +} + +func (c *QQClient) getSequence() uint32 { + return atomic.LoadUint32(&c.sig.Sequence) % 0x8000 } diff --git a/client/client.go b/client/client.go index 2954718f..689804dd 100644 --- a/client/client.go +++ b/client/client.go @@ -16,7 +16,7 @@ import ( ) var ( - resultStore = NewResultStore() + fetcher = newssofetcher() networkLogger = utils.GetLogger("network") ) diff --git a/client/message.go b/client/message.go index f310ca25..d4442f35 100644 --- a/client/message.go +++ b/client/message.go @@ -1,8 +1,6 @@ package client import ( - "sync/atomic" - message2 "github.com/LagrangeDev/LagrangeGo/message" "github.com/LagrangeDev/LagrangeGo/packets/pb/action" "github.com/LagrangeDev/LagrangeGo/packets/pb/message" @@ -20,7 +18,7 @@ func (c *QQClient) SendRawMessage(route *message.RoutingHead, body *message.Mess DivSeq: proto.Some(uint32(0)), }, Body: body, - Seq: proto.Some(atomic.LoadUint32(&c.sig.Sequence)), + Seq: proto.Some(c.getSequence()), Rand: proto.Some(crypto.RandU32()), } // grp_id not null diff --git a/client/resultStore.go b/client/resultStore.go index 8a58c4e7..d3874138 100644 --- a/client/resultStore.go +++ b/client/resultStore.go @@ -4,60 +4,57 @@ package client import ( "errors" - "sync" "time" + "github.com/RomiChan/syncx" + "github.com/LagrangeDev/LagrangeGo/packets/wtlogin" - "github.com/LagrangeDev/LagrangeGo/utils" ) -//nolint:unused -var resultLogger = utils.GetLogger("resultstore") +// var resultLogger = utils.GetLogger("resultstore") -// ResultStore 灵感来源于ddl的onebot适配器 -type ResultStore struct { - result sync.Map -} +// ssofetcher 灵感来源于ddl的onebot适配器 +type ssofetcher syncx.Map[uint32, chan *wtlogin.SSOPacket] -func NewResultStore() *ResultStore { - return &ResultStore{} +func newssofetcher() *ssofetcher { + return &ssofetcher{} } // ContainSeq 判断这个seq是否存在 -func (s *ResultStore) ContainSeq(seq int) bool { - _, ok := s.result.Load(seq) +func (s *ssofetcher) ContainSeq(seq int) bool { + _, ok := (*syncx.Map[uint32, chan *wtlogin.SSOPacket])(s).Load(uint32(seq)) return ok } // AddSeq 发消息的时候调用,把seq加到map里面 -func (s *ResultStore) AddSeq(seq int) { +func (s *ssofetcher) AddSeq(seq int) { resultChan := make(chan *wtlogin.SSOPacket, 1) - s.result.Store(seq, resultChan) + (*syncx.Map[uint32, chan *wtlogin.SSOPacket])(s).Store(uint32(seq), resultChan) } // DeleteSeq 删除seq -func (s *ResultStore) DeleteSeq(seq int) { - s.result.Delete(seq) +func (s *ssofetcher) DeleteSeq(seq int) { + (*syncx.Map[uint32, chan *wtlogin.SSOPacket])(s).Delete(uint32(seq)) } // AddResult 收到消息的时候调用,返回此seq是否存在,如果存在则存储数据 -func (s *ResultStore) AddResult(seq int, data *wtlogin.SSOPacket) bool { - if resultChan, ok := s.result.Load(seq); ok { - resultChan.(chan *wtlogin.SSOPacket) <- data +func (s *ssofetcher) AddResult(seq int, data *wtlogin.SSOPacket) bool { + if resultChan, ok := (*syncx.Map[uint32, chan *wtlogin.SSOPacket])(s).Load(uint32(seq)); ok { + resultChan <- data return true } return false } // Fecth 等待获取数据直到超时,这里找不到对应的seq会直接返回错误,务必在发包之前调用 AddSeq,如果发包出错可以 DeleteSeq -func (s *ResultStore) Fecth(seq, timeout int) (*wtlogin.SSOPacket, error) { - if resultChan, ok := s.result.Load(seq); ok { +func (s *ssofetcher) Fecth(seq, timeout int) (*wtlogin.SSOPacket, error) { + if resultChan, ok := (*syncx.Map[uint32, chan *wtlogin.SSOPacket])(s).Load(uint32(seq)); ok { // 确保读取完删除这个结果 - defer s.result.Delete(seq) + defer (*syncx.Map[uint32, chan *wtlogin.SSOPacket])(s).Delete(uint32(seq)) select { case <-time.After(time.Duration(timeout) * time.Second): return nil, errors.New("fetch timeout") - case result := <-(resultChan.(chan *wtlogin.SSOPacket)): + case result := <-resultChan: return result, nil } } diff --git a/go.mod b/go.mod index 55561a65..3f937fd2 100644 --- a/go.mod +++ b/go.mod @@ -4,6 +4,7 @@ go 1.20 require ( github.com/RomiChan/protobuf v0.1.1-0.20230204044148-2ed269a2e54d + github.com/RomiChan/syncx v0.0.0-20240418144900-b7402ffdebc7 github.com/fumiama/gofastTEA v0.0.10 github.com/fumiama/imgsz v0.0.4 github.com/mattn/go-colorable v0.1.13 @@ -11,7 +12,6 @@ require ( ) require ( - github.com/RomiChan/syncx v0.0.0-20240418144900-b7402ffdebc7 // indirect github.com/mattn/go-isatty v0.0.20 // indirect github.com/stretchr/testify v1.8.0 // indirect golang.org/x/image v0.16.0 // indirect diff --git a/packets/wtlogin/oicq.go b/packets/wtlogin/oicq.go index e78d3165..4222e075 100644 --- a/packets/wtlogin/oicq.go +++ b/packets/wtlogin/oicq.go @@ -86,7 +86,7 @@ func BuildLoginPacket(uin uint32, cmd string, appinfo *info.AppInfo, body []byte return frame } -func BuildUniPacket(uin, seq int, cmd string, sign map[string]string, +func BuildUniPacket(uin int, seq uint32, cmd string, sign map[string]string, appInfo *info.AppInfo, deviceInfo *info.DeviceInfo, sigInfo *info.SigInfo, body []byte) []byte { trace := generateTrace() diff --git a/utils/sign.go b/utils/sign.go index 2f717699..b3d3ec68 100644 --- a/utils/sign.go +++ b/utils/sign.go @@ -68,11 +68,11 @@ func containSignPKG(cmd string) bool { return ok } -func SignProvider(rawUrl string) func(string, int, []byte) map[string]string { +func SignProvider(rawUrl string) func(string, uint32, []byte) map[string]string { if rawUrl == "" { return nil } - return func(cmd string, seq int, buf []byte) map[string]string { + return func(cmd string, seq uint32, buf []byte) map[string]string { if !containSignPKG(cmd) { return nil } @@ -80,7 +80,7 @@ func SignProvider(rawUrl string) func(string, int, []byte) map[string]string { resp := signResponse{} err := httpGet(rawUrl, map[string]string{ "cmd": cmd, - "seq": strconv.Itoa(seq), + "seq": strconv.Itoa(int(seq)), "src": fmt.Sprintf("%x", buf), }, time.Duration(5)*time.Second, &resp) if err != nil {