From dd74a56548c5bc68d4b9e2fb14b2d3404111d165 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E6=BA=90=E6=96=87=E9=9B=A8?= <41315874+fumiama@users.noreply.github.com> Date: Thu, 9 May 2024 18:46:48 +0900 Subject: [PATCH] optimize(highway): http request MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit 远期目标为迁移到纯tcp的, 可并发上传且零拷贝的 https://github.com/Mrs4s/MiraiGo/tree/master/client/internal/highway --- client/base.go | 6 +- client/highway.go | 191 +++++++++++++++++----------------------- client/network.go | 24 ++--- client/richmedia.go | 19 ++-- utils/binary/builder.go | 15 ++++ utils/binary/reader.go | 160 ++++++++++++++++++--------------- 6 files changed, 204 insertions(+), 211 deletions(-) diff --git a/client/base.go b/client/base.go index d3ca9596..099d4b1c 100644 --- a/client/base.go +++ b/client/base.go @@ -18,7 +18,7 @@ import ( binary2 "github.com/LagrangeDev/LagrangeGo/utils/binary" ) -const Server = "msfwifi.3g.qq.com:8080" +const msfwifiServer = "msfwifi.3g.qq.com:8080" // NewQQclient 创建一个新的QQClient func NewQQclient(uin uint32, signUrl string, appInfo *info.AppInfo, deviceInfo *info.DeviceInfo, sig *info.SigInfo) *QQClient { @@ -31,7 +31,7 @@ func NewQQclient(uin uint32, signUrl string, appInfo *info.AppInfo, deviceInfo * // 128应该够用了吧 pushStore: make(chan *wtlogin.SSOPacket, 128), stopChan: make(chan struct{}), - tcp: NewTCPClient(Server, 5), + tcp: &TCPClient{}, cache: &cache.Cache{}, } client.Online.Store(false) @@ -233,7 +233,7 @@ func (c *QQClient) Loop() error { } func (c *QQClient) Connect() error { - err := c.tcp.Connect() + err := c.tcp.Connect(msfwifiServer, 5*time.Second) if err != nil { return err } diff --git a/client/highway.go b/client/highway.go index 58600e65..70e69b1f 100644 --- a/client/highway.go +++ b/client/highway.go @@ -2,10 +2,10 @@ package client import ( "bytes" - binary2 "encoding/binary" - "encoding/hex" + "errors" "fmt" "io" + "math/rand" "net/http" "net/url" "strconv" @@ -17,6 +17,11 @@ import ( "github.com/RomiChan/protobuf/proto" ) +const ( + uploadBlockSize = 1024 * 1024 + httpServiceType uint32 = 1 +) + type UpBlock struct { CommandId int Uin uint @@ -25,108 +30,84 @@ type UpBlock struct { Offset uint64 Ticket []byte FileMd5 []byte - Block []byte + Block io.Reader + BlockMd5 []byte + BlockSize uint32 ExtendInfo []byte Timestamp uint64 } -func (c *QQClient) GetServiceServer() ([]byte, map[uint32][]string) { +func (c *QQClient) EnsureHighwayServers() error { if c.highwayUri == nil || c.sigSession == nil { c.highwayUri = make(map[uint32][]string) packet, err := highway2.BuildHighWayUrlReq(c.sig.Tgt) if err != nil { - return nil, nil + return err } payload, err := c.SendUniPacketAndAwait("HttpConn.0x6ff_501", packet) if err != nil { - networkLogger.Errorf("Failed to get highway server: %v", err) - return nil, nil + return fmt.Errorf("get highway server: %v", err) } resp, err := highway2.ParseHighWayUrlReq(payload.Data) if err != nil { - networkLogger.Errorf("Failed to parse highway server: %v", err) - return nil, nil + return fmt.Errorf("parse highway server: %v", err) } for _, info := range resp.HttpConn.ServerInfos { servicetype := info.ServiceType for _, addr := range info.ServerAddrs { - ip := make([]byte, 4) - binary2.LittleEndian.PutUint32(ip, addr.IP) service := c.highwayUri[servicetype] - service = append(service, fmt.Sprintf("http://%d.%d.%d.%d:%d/cgi-bin/httpconn?htcmd=0x6FF0087&uin=%d", ip[0], ip[1], ip[2], ip[3], addr.Port, c.sig.Uin)) + service = append(service, fmt.Sprintf( + "http://%s:%d/cgi-bin/httpconn?htcmd=0x6FF0087&uin=%d", + le32toipstr(addr.IP), addr.Port, c.sig.Uin, + )) c.highwayUri[servicetype] = service } } c.sigSession = resp.HttpConn.SigSession } - return c.sigSession, c.highwayUri -} - -func (c *QQClient) UploadSrcByStreamAsync(commonId int, stream io.ReadSeeker, ticket []byte, md5 []byte, extendInfo []byte) bool { - // Get server URL - _, server := c.GetServiceServer() - if server == nil { - networkLogger.Errorln("Failed to get upload server") - return false - } - success := true - var upBlocks []UpBlock - data, err := io.ReadAll(stream) - if err != nil { - networkLogger.Errorln("Failed to read stream") - return false + if c.highwayUri == nil || c.sigSession == nil { + return errors.New("empty highway servers") } + return nil +} - fileSize := uint64(len(data)) - offset := uint64(0) - _, err = stream.Seek(0, io.SeekStart) +func (c *QQClient) UploadSrcByStream(commonId int, r io.Reader, fileSize uint64, md5 []byte, extendInfo []byte) error { + err := c.EnsureHighwayServers() if err != nil { - networkLogger.Errorln("Failed to seek stream") - return false - } - - for offset < fileSize { - var buffersize uint64 - if uint64(1024*1024) > fileSize-offset { - buffersize = fileSize - offset - } else { - buffersize = uint64(1024 * 1024) + return err + } + servers := c.highwayUri[httpServiceType] + server := servers[rand.Intn(len(servers))] + buffer := make([]byte, uploadBlockSize) + for offset := uint64(0); offset < fileSize; offset += uploadBlockSize { + if uploadBlockSize > fileSize-offset { + buffer = buffer[:fileSize-offset] } - buffer := make([]byte, buffersize) - payload, err := io.ReadFull(stream, buffer) + _, err := io.ReadFull(r, buffer) if err != nil { - networkLogger.Errorln("Failed to read stream") - return false + return err } - reqBody := UpBlock{ + err = c.SendUpBlock(&UpBlock{ CommandId: commonId, Uin: uint(c.sig.Uin), Sequence: uint(c.highwaySequence.Add(1)), FileSize: fileSize, Offset: offset, - Ticket: ticket, + Ticket: c.sigSession, FileMd5: md5, - Block: buffer, + Block: bytes.NewReader(buffer), + BlockMd5: crypto.MD5Digest(buffer), + BlockSize: uint32(len(buffer)), ExtendInfo: extendInfo, - } - upBlocks = append(upBlocks, reqBody) - offset += uint64(payload) - // 4 is HighwayConcurrent - if len(upBlocks) >= 4 || offset == fileSize { - for _, block := range upBlocks { - success = success && c.SendUpBlockAsync(block, server[1][0]) - if !success { - networkLogger.Errorln("Failed to send block") - return false - } - } - upBlocks = nil + }, server) + if err != nil { + return err } } - return success + return nil } -func (c *QQClient) SendUpBlockAsync(block UpBlock, server string) bool { +func (c *QQClient) SendUpBlock(block *UpBlock, server string) error { head := &highway.DataHighwayHead{ Version: 1, Uin: proto.Some(strconv.Itoa(int(block.Uin))), @@ -137,15 +118,14 @@ func (c *QQClient) SendUpBlockAsync(block UpBlock, server string) bool { DataFlag: 16, CommandId: uint32(block.CommandId), } - md5 := crypto.MD5Digest(block.Block) segHead := &highway.SegHead{ ServiceId: proto.Some(uint32(0)), Filesize: block.FileSize, DataOffset: proto.Some(block.Offset), - DataLength: uint32(len(block.Block)), + DataLength: uint32(block.BlockSize), RetCode: proto.Some(uint32(0)), ServiceTicket: block.Ticket, - Md5: md5, + Md5: block.BlockMd5, FileMd5: block.FileMd5, CacheAddr: proto.Some(uint32(0)), CachePort: proto.Some(uint32(0)), @@ -162,63 +142,61 @@ func (c *QQClient) SendUpBlockAsync(block UpBlock, server string) bool { Timestamp: block.Timestamp, MsgLoginSigHead: loginHead, } - isEnd := block.Offset+uint64(len(block.Block)) == block.FileSize - packet := binary.NewBuilder(nil) - packet.WriteBytes(block.Block, false) - payload, err := SendPacketAsync(highwayHead, packet, server, isEnd) + isEnd := block.Offset+uint64(block.BlockSize) == block.FileSize + payload, err := sendHighwayPacket(highwayHead, block.Block, block.BlockSize, server, isEnd) if err != nil { - networkLogger.Errorln("Failed to send packet ", err) - return false + return fmt.Errorf("send highway packet: %v", err) } - resphead, respbody, err := ParsePacket(payload) + defer payload.Close() + resphead, respbody, err := parseHighwayPacket(payload) if err != nil { - networkLogger.Errorln("Failed to parse packet ", err) - return false + return fmt.Errorf("parse highway packet: %v", err) } networkLogger.Debugf("Highway Block Result: %d | %d | %x | %v", resphead.ErrorCode, resphead.MsgSegHead.RetCode.Unwrap(), resphead.BytesRspExtendInfo, respbody) - return resphead.ErrorCode == 0 + if resphead.ErrorCode != 0 { + return errors.New("highway error code: " + strconv.Itoa(int(resphead.ErrorCode))) + } + return nil } -func ParsePacket(data []byte) (head *highway.RespDataHighwayHead, body *binary.Reader, err error) { - reader := binary.NewReader(data) - if reader.ReadBytesNoCopy(1)[0] == 0x28 { - headlength := reader.ReadU32() - bodylength := reader.ReadU32() - head = &highway.RespDataHighwayHead{} - headraw := reader.ReadBytesNoCopy(int(int64(headlength))) - err = proto.Unmarshal(headraw, head) - if err != nil { - return nil, nil, err - } - body = binary.NewReader(reader.ReadBytesNoCopy(int(bodylength))) - if reader.ReadBytesNoCopy(1)[0] == 0x29 { - return head, body, nil - } +func parseHighwayPacket(data io.Reader) (head *highway.RespDataHighwayHead, body *binary.Reader, err error) { + reader := binary.ParseReader(data) + if reader.ReadBytesNoCopy(1)[0] != 0x28 { + return nil, nil, errors.New("invalid highway packet") + } + headlength := reader.ReadU32() + _ = reader.ReadU32() // body len + head = &highway.RespDataHighwayHead{} + headraw := reader.ReadBytesNoCopy(int(int64(headlength))) + err = proto.Unmarshal(headraw, head) + if err != nil { + return nil, nil, err } - return nil, nil, err + if reader.ReadBytesNoCopy(1)[0] != 0x29 { + return nil, nil, errors.New("invalid highway head") + } + return head, reader, nil } -func SendPacketAsync(packet *highway.ReqDataHighwayHead, buffer *binary.Builder, serverURL string, end bool) ([]byte, error) { +func sendHighwayPacket(packet *highway.ReqDataHighwayHead, buffer io.Reader, bufferSize uint32, serverURL string, end bool) (io.ReadCloser, error) { marshal, err := proto.Marshal(packet) if err != nil { return nil, err } - println(hex.EncodeToString(marshal)) - writer := binary.NewBuilder(nil). WriteBytes([]byte{0x28}, false). WriteU32(uint32(len(marshal))). - WriteU32(uint32(buffer.Len())). - WriteBytes(marshal, false). - WriteBytes(buffer.ToBytes(), false). - WriteBytes([]byte{0x29}, false) + WriteU32(bufferSize). + WriteBytes(marshal, false) + _, _ = io.Copy(writer, buffer) + writer.Write([]byte{0x29}) - return SendDataAsync(writer.ToBytes(), serverURL, end) + return postHighwayContent(writer.ToReader(), serverURL, end) } -func SendDataAsync(packet []byte, serverURL string, end bool) ([]byte, error) { +func postHighwayContent(content io.Reader, serverURL string, end bool) (io.ReadCloser, error) { // Parse server URL server, err := url.Parse(serverURL) if err != nil { @@ -226,7 +204,7 @@ func SendDataAsync(packet []byte, serverURL string, end bool) ([]byte, error) { } // Create request - content := bytes.NewBuffer(packet) + networkLogger.Debugln("post content to highway url:", server) req, err := http.NewRequest("POST", server.String(), content) if err != nil { return nil, err @@ -244,12 +222,5 @@ func SendDataAsync(packet []byte, serverURL string, end bool) ([]byte, error) { if err != nil { return nil, err } - defer resp.Body.Close() - - // Read response data - data, err := io.ReadAll(resp.Body) - if err != nil { - return nil, err - } - return data, nil + return resp.Body, nil } diff --git a/client/network.go b/client/network.go index ac0081c8..cefb75ce 100644 --- a/client/network.go +++ b/client/network.go @@ -1,5 +1,7 @@ package client +// from https://github.com/Mrs4s/MiraiGo/blob/master/client/internal/network/conn.go + import ( "errors" "io" @@ -12,23 +14,11 @@ var ErrConnectionClosed = errors.New("connection closed") type TCPClient struct { lock sync.RWMutex - - addr string - conn net.Conn - timeout int - - connected bool -} - -func NewTCPClient(addr string, timeout int) *TCPClient { - return &TCPClient{ - addr: addr, - timeout: timeout, - } + conn net.Conn } -func (c *TCPClient) Connect() error { - conn, err := net.DialTimeout("tcp", c.addr, time.Duration(c.timeout)*time.Second) +func (c *TCPClient) Connect(addr string, timeout time.Duration) error { + conn, err := net.DialTimeout("tcp", addr, timeout) if err != nil { return err } @@ -36,7 +26,6 @@ func (c *TCPClient) Connect() error { c.lock.Lock() defer c.lock.Unlock() c.conn = conn - c.connected = true return nil } @@ -75,7 +64,6 @@ func (c *TCPClient) Close() { _ = c.conn.Close() networkLogger.Error("tcp closed") c.conn = nil - c.connected = false } } @@ -86,5 +74,5 @@ func (c *TCPClient) getConn() net.Conn { } func (c *TCPClient) IsClosed() bool { - return !c.connected + return c.conn == nil } diff --git a/client/richmedia.go b/client/richmedia.go index d6ccd77b..63044198 100644 --- a/client/richmedia.go +++ b/client/richmedia.go @@ -5,7 +5,6 @@ import ( "encoding/binary" "encoding/hex" "errors" - "io" "net/netip" "github.com/LagrangeDev/LagrangeGo/message" @@ -80,9 +79,12 @@ func (c *QQClient) ImageUploadPrivate(targetUid string, element message.IMessage if err != nil { return nil, err } - sigSession, _ := c.GetServiceServer() - if !c.UploadSrcByStreamAsync(1003, io.ReadSeeker(bytes.NewReader(image.Stream)), sigSession, md5hash, extStream) { - return nil, errors.New("upload failed") + err = c.UploadSrcByStream(1003, + bytes.NewReader(image.Stream), uint64(len(image.Stream)), + md5hash, extStream, + ) + if err != nil { + return nil, err } } image.MsgInfo = uploadResp.Upload.MsgInfo @@ -139,9 +141,12 @@ func (c *QQClient) ImageUploadGroup(groupUin uint32, element message.IMessageEle if err != nil { return nil, err } - sigSession, _ := c.GetServiceServer() - if !c.UploadSrcByStreamAsync(1004, io.ReadSeeker(bytes.NewReader(image.Stream)), sigSession, md5hash, extStream) { - return nil, errors.New("upload failed") + err = c.UploadSrcByStream(1004, + bytes.NewReader(image.Stream), uint64(len(image.Stream)), + md5hash, extStream, + ) + if err != nil { + return nil, err } } image.MsgInfo = uploadResp.Upload.MsgInfo diff --git a/utils/binary/builder.go b/utils/binary/builder.go index 58ba1ac1..771a538e 100644 --- a/utils/binary/builder.go +++ b/utils/binary/builder.go @@ -3,6 +3,7 @@ package binary import ( "bytes" "encoding/binary" + "io" "math" "strconv" "unsafe" @@ -16,6 +17,8 @@ type Builder struct { buffer *bytes.Buffer key ftea.TEA usetea bool + io.Writer + io.ReaderFrom } func NewBuilder(key []byte) *Builder { @@ -46,6 +49,10 @@ func (b *Builder) pack(v any) error { return binary.Write(b.buffer, binary.BigEndian, v) } +func (b *Builder) ToReader() io.Reader { + return b.buffer +} + // ToBytes return data with tea encryption func (b *Builder) ToBytes() []byte { return b.data() @@ -107,6 +114,14 @@ func (b *Builder) WritePacketString(s, prefix string, withPrefix bool) *Builder return b.WritePacketBytes(utils.S2B(s), prefix, withPrefix) } +func (b *Builder) Write(p []byte) (n int, err error) { + return b.buffer.Write(p) +} + +func (b *Builder) ReadFrom(r io.Reader) (n int64, err error) { + return io.Copy(b.buffer, r) +} + func (b *Builder) WriteBytes(v []byte, withLength bool) *Builder { if withLength { b.WriteU16(uint16(len(v))) diff --git a/utils/binary/reader.go b/utils/binary/reader.go index 1f6b94a4..4030dbff 100644 --- a/utils/binary/reader.go +++ b/utils/binary/reader.go @@ -2,50 +2,81 @@ package binary import ( "encoding/binary" + "io" + "strconv" + "unsafe" "github.com/LagrangeDev/LagrangeGo/utils" ) type Reader struct { + reader io.Reader buffer []byte pos int } +func ParseReader(reader io.Reader) *Reader { + return &Reader{ + reader: reader, + } +} + func NewReader(buffer []byte) *Reader { return &Reader{ buffer: buffer, - pos: 0, } } -func (r *Reader) Remain() int { - return len(r.buffer) - r.pos +func (r *Reader) String() string { + if r.reader != nil { + data, err := io.ReadAll(r.reader) + if err != nil { + return err.Error() + } + return utils.B2S(data) + } + return utils.B2S(r.buffer[r.pos:]) } func (r *Reader) ReadU8() (v uint8) { + if r.reader != nil { + _, _ = r.reader.Read(unsafe.Slice(&v, 1)) + return + } v = r.buffer[r.pos] r.pos++ return } -func (r *Reader) ReadU16() (v uint16) { - v = binary.BigEndian.Uint16(r.buffer[r.pos : r.pos+2]) - r.pos += 2 +func readint[T ~uint16 | ~uint32 | ~uint64](r *Reader) (v T) { + sz := unsafe.Sizeof(v) + buf := make([]byte, 8) + if r.reader != nil { + _, _ = r.reader.Read(buf[8-sz:]) + } else { + copy(buf[8-sz:], r.buffer[r.pos:r.pos+int(sz)]) + r.pos += int(sz) + } + v = (T)(binary.BigEndian.Uint64(buf)) return } +func (r *Reader) ReadU16() (v uint16) { + return readint[uint16](r) +} + func (r *Reader) ReadU32() (v uint32) { - v = binary.BigEndian.Uint32(r.buffer[r.pos : r.pos+4]) - r.pos += 4 - return + return readint[uint32](r) } func (r *Reader) ReadU64() (v uint64) { - v = binary.BigEndian.Uint64(r.buffer[r.pos : r.pos+8]) - r.pos += 8 - return + return readint[uint64](r) } func (r *Reader) SkipBytes(length int) { + if r.reader != nil { + _, _ = r.reader.Read(make([]byte, length)) + return + } r.pos += length } @@ -53,6 +84,9 @@ func (r *Reader) SkipBytes(length int) { // // 如需使用, 请确保 Reader 未被回收 func (r *Reader) ReadBytesNoCopy(length int) (v []byte) { + if r.reader != nil { + return r.ReadBytes(length) + } v = r.buffer[r.pos : r.pos+length] r.pos += length return @@ -61,8 +95,12 @@ func (r *Reader) ReadBytesNoCopy(length int) (v []byte) { func (r *Reader) ReadBytes(length int) (v []byte) { // 返回一个全新的数组罢 v = make([]byte, length) - copy(v, r.buffer[r.pos:r.pos+length]) - r.pos += length + if r.reader != nil { + _, _ = r.reader.Read(v) + } else { + copy(v, r.buffer[r.pos:r.pos+length]) + r.pos += length + } return } @@ -72,69 +110,50 @@ func (r *Reader) ReadString(length int) string { func (r *Reader) SkipBytesWithLength(prefix string, withPerfix bool) { var length int + switch prefix { + case "u8": + length = int(r.ReadU8()) + case "u16": + length = int(r.ReadU16()) + case "u32": + length = int(r.ReadU32()) + case "u64": + length = int(r.ReadU64()) + default: + panic("invaild prefix") + } if withPerfix { - switch prefix { - case "u8": - length = int(r.ReadU8() - 1) - case "u16": - length = int(r.ReadU16() - 2) - case "u32": - length = int(r.ReadU32() - 4) - case "u64": - length = int(r.ReadU64() - 8) - default: - panic("invaild prefix") - } - } else { - switch prefix { - case "u8": - length = int(r.ReadU8()) - case "u16": - length = int(r.ReadU16()) - case "u32": - length = int(r.ReadU32()) - case "u64": - length = int(r.ReadU64()) - default: - panic("invaild prefix") + plus, err := strconv.Atoi(prefix[1:]) + if err != nil { + panic(err) } + length -= plus / 8 } - r.pos += length + r.SkipBytes(length) } -func (r *Reader) ReadBytesWithLength(prefix string, withPerfix bool) (v []byte) { +func (r *Reader) ReadBytesWithLength(prefix string, withPerfix bool) []byte { var length int + switch prefix { + case "u8": + length = int(r.ReadU8()) + case "u16": + length = int(r.ReadU16()) + case "u32": + length = int(r.ReadU32()) + case "u64": + length = int(r.ReadU64()) + default: + panic("invaild prefix") + } if withPerfix { - switch prefix { - case "u8": - length = int(r.ReadU8() - 1) - case "u16": - length = int(r.ReadU16() - 2) - case "u32": - length = int(r.ReadU32() - 4) - case "u64": - length = int(r.ReadU64() - 8) - default: - panic("invaild prefix") - } - } else { - switch prefix { - case "u8": - length = int(r.ReadU8()) - case "u16": - length = int(r.ReadU16()) - case "u32": - length = int(r.ReadU32()) - case "u64": - length = int(r.ReadU64()) - default: - panic("invaild prefix") + plus, err := strconv.Atoi(prefix[1:]) + if err != nil { + panic(err) } + length -= plus / 8 } - v = make([]byte, length) - copy(v, r.buffer[r.pos:r.pos+length]) - r.pos += length - return + return r.ReadBytes(length) } func (r *Reader) ReadStringWithLength(prefix string, withPerfix bool) string { @@ -152,11 +171,6 @@ func (r *Reader) ReadTlv() (result map[uint16][]byte) { return } -// go的残废泛型 -//func (r *Reader) ReadStruct(){ -// -//} - func (r *Reader) ReadI8() (v int8) { return int8(r.ReadU8()) }