From b71a0541bd18d8fcb8d4c6224d0ac4e27853b1f1 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E6=BA=90=E6=96=87=E9=9B=A8?= <41315874+fumiama@users.noreply.github.com> Date: Tue, 6 Aug 2024 20:30:33 +0800 Subject: [PATCH] feat(p2p): add more link to tcp --- gold/head/packet.go | 4 +- gold/link/recv.go | 4 +- gold/p2p/tcp/init.go | 2 + gold/p2p/tcp/pdu.go | 13 +- gold/p2p/tcp/tcp.go | 192 +++++++++++++++++++++++---- lower/nic.go | 2 +- upper/services/tunnel/tunnel_test.go | 26 ++-- upper/services/wg/wg.go | 2 +- 8 files changed, 198 insertions(+), 47 deletions(-) diff --git a/gold/head/packet.go b/gold/head/packet.go index 62301e6..e5509f0 100644 --- a/gold/head/packet.go +++ b/gold/head/packet.go @@ -199,7 +199,7 @@ func (p *Packet) FillHash() { h := blake2b.New256() _, err := h.Write(p.Body()) if err != nil { - logrus.Error("[packet] err when fill hash:", err) + logrus.Errorln("[packet] err when fill hash:", err) return } hsh := h.Sum(p.Hash[:0]) @@ -213,7 +213,7 @@ func (p *Packet) IsVaildHash() bool { h := blake2b.New256() _, err := h.Write(p.Body()) if err != nil { - logrus.Error("[packet] err when check hash:", err) + logrus.Errorln("[packet] err when check hash:", err) return false } var sum [32]byte diff --git a/gold/link/recv.go b/gold/link/recv.go index f54535a..2d71fbd 100644 --- a/gold/link/recv.go +++ b/gold/link/recv.go @@ -47,7 +47,9 @@ func (m *Me) wait(data []byte) *head.Packet { logrus.Debugf("[recv] packet crc %016x, seq %08x, xored crc %016x", crclog, seq, crc) } if m.recved.Get(crc) { - logrus.Warnln("[recv] ignore duplicated crc packet", strconv.FormatUint(crc, 16)) + if config.ShowDebugLog { + logrus.Debugln("[recv] ignore duplicated crc packet", strconv.FormatUint(crc, 16)) + } return nil } if config.ShowDebugLog { diff --git a/gold/p2p/tcp/init.go b/gold/p2p/tcp/init.go index b6ecd67..0e8f4e0 100644 --- a/gold/p2p/tcp/init.go +++ b/gold/p2p/tcp/init.go @@ -12,6 +12,7 @@ import ( type Config struct { DialTimeout time.Duration PeersTimeout time.Duration + KeepInterval time.Duration ReceiveChannelSize int } @@ -34,6 +35,7 @@ func newEndpoint(endpoint string, configs ...any) (*EndPoint, error) { addr: net.TCPAddrFromAddrPort(addr), dialtimeout: cfg.DialTimeout, peerstimeout: cfg.PeersTimeout, + keepinterval: cfg.KeepInterval, recvchansize: cfg.ReceiveChannelSize, }, nil } diff --git a/gold/p2p/tcp/pdu.go b/gold/p2p/tcp/pdu.go index 3fe5d72..a466ea6 100644 --- a/gold/p2p/tcp/pdu.go +++ b/gold/p2p/tcp/pdu.go @@ -21,6 +21,7 @@ type packetType uint8 const ( packetTypeKeepAlive packetType = iota packetTypeNormal + packetTypeSubKeepAlive packetTypeTop ) @@ -87,7 +88,7 @@ func (p *packet) WriteTo(w io.Writer) (n int64, err error) { return io.Copy(w, &buf) } -func isvalid(tcpconn *net.TCPConn) bool { +func isvalid(tcpconn *net.TCPConn) (issub, ok bool) { pckt := packet{} stopch := make(chan struct{}) @@ -107,7 +108,7 @@ func isvalid(tcpconn *net.TCPConn) bool { if config.ShowDebugLog { logrus.Debugln("[tcp] validate recv from", tcpconn.RemoteAddr(), "timeout") } - return false + return case <-copych: t.Stop() } @@ -116,17 +117,17 @@ func isvalid(tcpconn *net.TCPConn) bool { if config.ShowDebugLog { logrus.Debugln("[tcp] validate recv from", tcpconn.RemoteAddr(), "err:", err) } - return false + return } - if pckt.typ != packetTypeKeepAlive { + if pckt.typ != packetTypeKeepAlive && pckt.typ != packetTypeSubKeepAlive { if config.ShowDebugLog { logrus.Debugln("[tcp] validate got invalid typ", pckt.typ, "from", tcpconn.RemoteAddr()) } - return false + return } if config.ShowDebugLog { logrus.Debugln("[tcp] passed validate recv from", tcpconn.RemoteAddr()) } - return true + return pckt.typ == packetTypeSubKeepAlive, true } diff --git a/gold/p2p/tcp/tcp.go b/gold/p2p/tcp/tcp.go index bdce586..1ff7fbd 100644 --- a/gold/p2p/tcp/tcp.go +++ b/gold/p2p/tcp/tcp.go @@ -21,6 +21,7 @@ type EndPoint struct { addr *net.TCPAddr dialtimeout time.Duration peerstimeout time.Duration + keepinterval time.Duration recvchansize int } @@ -80,6 +81,7 @@ func (ep *EndPoint) Listen() (p2p.Conn, error) { }), recv: make(chan *connrecv, chansz), cplk: &sync.Mutex{}, + sblk: &sync.RWMutex{}, } go conn.accept() return conn, nil @@ -91,6 +93,11 @@ type connrecv struct { pckt packet } +type subconn struct { + cplk sync.Mutex + conn *net.TCPConn +} + // Conn 伪装成无状态的有状态连接 type Conn struct { addr *EndPoint @@ -98,6 +105,8 @@ type Conn struct { peers *ttl.Cache[string, *net.TCPConn] recv chan *connrecv cplk *sync.Mutex + sblk *sync.RWMutex + subs []*subconn } func (conn *Conn) accept() { @@ -115,32 +124,54 @@ func (conn *Conn) accept() { _ = conn.Close() newc, err := conn.addr.Listen() if err != nil { - logrus.Warn("[tcp] re-listen on", conn.addr, "err:", err) + logrus.Warnln("[tcp] re-listen on", conn.addr, "err:", err) return } *conn = *newc.(*Conn) - logrus.Info("[tcp] re-listen on", conn.addr) + logrus.Infoln("[tcp] re-listen on", conn.addr) continue } go conn.receive(tcpconn, false) } } +func delsubs(i int, subs []*subconn) []*subconn { + switch i { + case 0: + subs = subs[1:] + case len(subs) - 1: + subs = subs[:len(subs)-1] + default: + subs = append(subs[:i], subs[i+1:]...) + } + return subs +} + func (conn *Conn) receive(tcpconn *net.TCPConn, hasvalidated bool) { ep, _ := newEndpoint(tcpconn.RemoteAddr().String(), &Config{ DialTimeout: conn.addr.dialtimeout, PeersTimeout: conn.addr.peerstimeout, + KeepInterval: conn.addr.keepinterval, ReceiveChannelSize: conn.addr.recvchansize, }) + issub, ok := false, false + if !hasvalidated { - if !isvalid(tcpconn) { + issub, ok = isvalid(tcpconn) + if !ok { return } if config.ShowDebugLog { - logrus.Debugln("[tcp] accept from", ep) + logrus.Debugln("[tcp] accept from", ep, "issub:", issub) + } + if issub { + conn.sblk.Lock() + conn.subs = append(conn.subs, &subconn{conn: tcpconn}) + conn.sblk.Unlock() + } else { + conn.peers.Set(ep.String(), tcpconn) } - conn.peers.Set(ep.String(), tcpconn) } peerstimeout := conn.addr.peerstimeout @@ -148,15 +179,33 @@ func (conn *Conn) receive(tcpconn *net.TCPConn, hasvalidated bool) { peerstimeout = time.Second * 30 } peerstimeout *= 2 - defer conn.peers.Delete(ep.String()) + if issub { + defer conn.peers.Delete(ep.String()) + } else { + defer func() { + conn.sblk.Lock() + for i, sub := range conn.subs { + if sub.conn == tcpconn { + conn.subs = delsubs(i, conn.subs) + break + } + } + conn.sblk.Unlock() + }() + } + + go conn.keep(ep) + for { r := &connrecv{addr: ep} if conn.addr == nil || conn.lstn == nil || conn.peers == nil || conn.recv == nil { return } - tcpconn := conn.peers.Get(ep.String()) - if tcpconn == nil { - return + if !issub { + tcpconn = conn.peers.Get(ep.String()) + if tcpconn == nil { + return + } } r.conn = tcpconn @@ -204,6 +253,46 @@ func (conn *Conn) receive(tcpconn *net.TCPConn, hasvalidated bool) { } } +func (conn *Conn) keep(ep *EndPoint) { + keepinterval := ep.keepinterval + if keepinterval < time.Second*4 { + keepinterval = time.Second * 4 + } + t := time.NewTicker(keepinterval) + defer t.Stop() + for range t.C { + if conn.addr == nil { + return + } + tcpconn := conn.peers.Get(ep.String()) + if tcpconn != nil { + _, err := io.Copy(tcpconn, &packet{typ: packetTypeKeepAlive}) + if conn.addr == nil { + return + } + if err != nil { + logrus.Warnln("[tcp] keep main conn alive to", ep, "err:", err) + conn.peers.Delete(ep.String()) + } else if config.ShowDebugLog { + logrus.Debugln("[tcp] keep main conn alive to", ep) + } + } + conn.sblk.RLock() + for i, sub := range conn.subs { + _, err := io.Copy(sub.conn, &packet{typ: packetTypeSubKeepAlive}) + if conn.addr == nil { + return + } + if err != nil { + logrus.Warnln("[tcp] keep sub conn alive to", sub.conn.RemoteAddr(), "err:", err) + conn.subs = delsubs(i, conn.subs) // del 1 link at once + break + } + } + conn.sblk.RUnlock() + } +} + func (conn *Conn) Close() error { if conn.lstn != nil { _ = conn.lstn.Close() @@ -246,20 +335,28 @@ func (conn *Conn) ReadFromPeer(b []byte) (int, p2p.EndPoint, error) { return n, p.addr, nil } -func (conn *Conn) WriteToPeer(b []byte, ep p2p.EndPoint) (n int, err error) { - tcpep, ok := ep.(*EndPoint) - if !ok { - return 0, p2p.ErrEndpointTypeMistatch - } - blen := len(b) - if blen >= 65536 { - return 0, errors.New("data size " + strconv.Itoa(blen) + " is too large") - } +// writeToPeer after acquiring lock +func (conn *Conn) writeToPeer(b []byte, tcpep *EndPoint, issub bool) (n int, err error) { retried := false - conn.cplk.Lock() - defer conn.cplk.Unlock() - tcpconn := conn.peers.Get(tcpep.String()) + ok := false + var ( + tcpconn *net.TCPConn + subc *subconn + ) RECONNECT: + if issub { + conn.sblk.RLock() + for _, sub := range conn.subs { + if sub.cplk.TryLock() { + tcpconn = sub.conn + subc = sub + break + } + } + conn.sblk.RUnlock() + } else { + tcpconn = conn.peers.Get(tcpep.String()) + } if tcpconn == nil { dialtimeout := tcpep.dialtimeout if dialtimeout < time.Second { @@ -278,9 +375,13 @@ RECONNECT: if !ok { return 0, errors.New("expect *net.TCPConn but got " + reflect.ValueOf(cn).Type().String()) } - _, err = io.Copy(tcpconn, &packet{ - typ: packetTypeKeepAlive, - }) + pkt := &packet{} + if issub { + pkt.typ = packetTypeSubKeepAlive + } else { + pkt.typ = packetTypeKeepAlive + } + _, err = io.Copy(tcpconn, pkt) if err != nil { if config.ShowDebugLog { logrus.Debugln("[tcp] dial to", tcpep.addr, "success, but write err:", err) @@ -290,23 +391,58 @@ RECONNECT: if config.ShowDebugLog { logrus.Debugln("[tcp] dial to", tcpep.addr, "success, local:", tcpconn.LocalAddr()) } - conn.peers.Set(tcpep.String(), tcpconn) - go conn.receive(tcpconn, true) + if !issub { + conn.peers.Set(tcpep.String(), tcpconn) + } else { + conn.sblk.Lock() + conn.subs = append(conn.subs, &subconn{conn: tcpconn}) + conn.sblk.Unlock() + go conn.receive(tcpconn, true) + } } else if config.ShowDebugLog { logrus.Debugln("[tcp] reuse tcpconn from", tcpconn.LocalAddr(), "to", tcpconn.RemoteAddr()) } cnt, err := io.Copy(tcpconn, &packet{ typ: packetTypeNormal, - len: uint16(blen), + len: uint16(len(b)), dat: b, }) if err != nil { - conn.peers.Delete(tcpep.String()) + if subc == nil { + conn.peers.Delete(tcpep.String()) + } else { + conn.sblk.Lock() + for i, sub := range conn.subs { + if sub == subc { + conn.subs = delsubs(i, conn.subs) + break + } + } + conn.sblk.Unlock() + } if !retried { retried = true tcpconn = nil goto RECONNECT } } + if subc != nil { + subc.cplk.Unlock() + } return int(cnt) - 3, err } + +func (conn *Conn) WriteToPeer(b []byte, ep p2p.EndPoint) (n int, err error) { + tcpep, ok := ep.(*EndPoint) + if !ok { + return 0, p2p.ErrEndpointTypeMistatch + } + if len(b) >= 65536 { + return 0, errors.New("data size " + strconv.Itoa(len(b)) + " is too large") + } + if !conn.cplk.TryLock() { + return conn.writeToPeer(b, tcpep, true) + } + defer conn.cplk.Unlock() + return conn.writeToPeer(b, tcpep, false) +} diff --git a/lower/nic.go b/lower/nic.go index 0bdeb1b..3638e63 100644 --- a/lower/nic.go +++ b/lower/nic.go @@ -27,7 +27,7 @@ type NICIO struct { func NewNIC(ip net.IP, subnet *net.IPNet, mtu string, cidrs ...string) *NICIO { ifce, err := water.New(water.Config{DeviceType: water.TUN}) if err != nil { - logrus.Error(err) + logrus.Errorln(err) os.Exit(1) } subn, bitsn := subnet.Mask.Size() diff --git a/upper/services/tunnel/tunnel_test.go b/upper/services/tunnel/tunnel_test.go index e5742ef..7f8d86f 100644 --- a/upper/services/tunnel/tunnel_test.go +++ b/upper/services/tunnel/tunnel_test.go @@ -107,7 +107,7 @@ func testTunnel(t *testing.T, nw string, isplain bool, pshk *[32]byte, mtu uint1 time.Sleep(time.Second) // wait link up sendb := ([]byte)("1234") - tunnme.Write(sendb) + go tunnme.Write(sendb) buf := make([]byte, 4) tunnpeer.Read(buf) if string(sendb) != string(buf) { @@ -117,7 +117,7 @@ func testTunnel(t *testing.T, nw string, isplain bool, pshk *[32]byte, mtu uint1 sendb = make([]byte, 4096) rand.Read(sendb) - tunnme.Write(sendb) + go tunnme.Write(sendb) buf = make([]byte, 4096) _, err = io.ReadFull(&tunnpeer, buf) if err != nil { @@ -127,13 +127,22 @@ func testTunnel(t *testing.T, nw string, isplain bool, pshk *[32]byte, mtu uint1 t.Fatal("error: recv 4096 bytes data") } - sendb = make([]byte, 65535) + sendbufs := make(chan []byte, 32) + + go func() { + for i := 0; i < 32; i++ { + sendb := make([]byte, 65535) + rand.Read(sendb) + n, _ := tunnme.Write(sendb) + sendbufs <- sendb + t.Log("loop", i, "write", n, "bytes") + } + close(sendbufs) + }() buf = make([]byte, 65535) - for i := 0; i < 32; i++ { - rand.Read(sendb) - n, _ := tunnme.Write(sendb) - t.Log("loop", i, "write", n, "bytes") - n, err = io.ReadFull(&tunnpeer, buf) + i := 0 + for sendb := range sendbufs { + n, err := io.ReadFull(&tunnpeer, buf) if err != nil { t.Fatal(err) } @@ -141,6 +150,7 @@ func testTunnel(t *testing.T, nw string, isplain bool, pshk *[32]byte, mtu uint1 if string(sendb) != string(buf) { t.Fatal("loop", i, "error: recv 65535 bytes data") } + i++ } rand.Read(sendb) diff --git a/upper/services/wg/wg.go b/upper/services/wg/wg.go index 5a8b33a..b8e1f54 100644 --- a/upper/services/wg/wg.go +++ b/upper/services/wg/wg.go @@ -59,7 +59,7 @@ func (wg *WG) Start(srcport, destport uint16) { func (wg *WG) Run(srcport, destport uint16) { wg.init(srcport, destport) _, _ = wg.me.ListenNIC() - logrus.Info("[wg] stopped") + logrus.Infoln("[wg] stopped") } func (wg *WG) Stop() {