Skip to content

Commit

Permalink
fix(network): close stream on timeout (#1520)
Browse files Browse the repository at this point in the history
  • Loading branch information
themantre authored Oct 3, 2024
1 parent a91838c commit ae57202
Show file tree
Hide file tree
Showing 11 changed files with 181 additions and 88 deletions.
2 changes: 1 addition & 1 deletion go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ require (
github.com/spf13/cobra v1.8.1
github.com/spf13/pflag v1.0.5
github.com/stretchr/testify v1.9.0
github.com/syndtr/goleveldb v1.0.1-0.20210819022825-2ae1ddf74ef7 // Don't upgrade it! due to memory leak issue.
github.com/syndtr/goleveldb v1.0.1-0.20210819022825-2ae1ddf74ef7
github.com/tyler-smith/go-bip39 v1.1.0
go.nanomsg.org/mangos/v3 v3.4.2
golang.org/x/crypto v0.27.0
Expand Down
13 changes: 8 additions & 5 deletions network/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package network

import (
"fmt"
"time"

lp2pcore "github.com/libp2p/go-libp2p/core"
lp2ppeer "github.com/libp2p/go-libp2p/core/peer"
Expand All @@ -25,11 +26,12 @@ type Config struct {
ForcePrivateNetwork bool `toml:"force_private_network"`

// Private configs
NetworkName string `toml:"-"`
DefaultPort int `toml:"-"`
DefaultBootstrapAddrStrings []string `toml:"-"`
IsBootstrapper bool `toml:"-"`
PeerStorePath string `toml:"-"`
NetworkName string `toml:"-"`
DefaultPort int `toml:"-"`
DefaultBootstrapAddrStrings []string `toml:"-"`
IsBootstrapper bool `toml:"-"`
PeerStorePath string `toml:"-"`
StreamTimeout time.Duration `toml:"-"`
}

func DefaultConfig() *Config {
Expand All @@ -50,6 +52,7 @@ func DefaultConfig() *Config {
DefaultPort: 0,
IsBootstrapper: false,
PeerStorePath: "peers.json",
StreamTimeout: 20 * time.Second,
}
}

Expand Down
4 changes: 2 additions & 2 deletions network/gossip.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,8 +27,8 @@ type gossipService struct {
logger *logger.SubLogger
}

func newGossipService(ctx context.Context, host lp2phost.Host, eventCh chan Event,
conf *Config, log *logger.SubLogger,
func newGossipService(ctx context.Context, host lp2phost.Host, conf *Config,
eventCh chan Event, log *logger.SubLogger,
) *gossipService {
opts := []lp2pps.Option{
lp2pps.WithFloodPublish(true),
Expand Down
10 changes: 5 additions & 5 deletions network/network.go
Original file line number Diff line number Diff line change
Expand Up @@ -254,11 +254,11 @@ func makeNetwork(conf *Config, log *logger.SubLogger, opts []lp2p.Option) (*netw
self.mdns = newMdnsService(ctx, self.host, self.logger)
}

self.dht = newDHTService(self.ctx, self.host, kadProtocolID, conf, self.logger)
self.peerMgr = newPeerMgr(ctx, host, conf, self.logger)
self.stream = newStreamService(ctx, self.host, streamProtocolID, self.eventChannel, self.logger)
self.gossip = newGossipService(ctx, self.host, self.eventChannel, conf, self.logger)
self.notifee = newNotifeeService(ctx, self.host, self.eventChannel, self.peerMgr, streamProtocolID, self.logger)
self.dht = newDHTService(ctx, host, kadProtocolID, conf, self.logger)
self.stream = newStreamService(ctx, host, conf, streamProtocolID, self.eventChannel, self.logger)
self.gossip = newGossipService(ctx, host, conf, self.eventChannel, self.logger)
self.notifee = newNotifeeService(ctx, host, self.eventChannel, self.peerMgr, streamProtocolID, self.logger)

self.logger.Info("network setup", "id", self.host.ID(),
"name", conf.NetworkName,
Expand Down Expand Up @@ -372,7 +372,7 @@ func (n *network) Protect(pid lp2pcore.PeerID, tag string) {
// It uses a goroutine to ensure that if sending is blocked, receiving messages won't be blocked.
func (n *network) SendTo(msg []byte, pid lp2pcore.PeerID) {
go func() {
err := n.stream.SendRequest(msg, pid)
_, err := n.stream.SendRequest(msg, pid)
if err != nil {
n.logger.Warn("error on sending msg", "pid", pid, "error", err)
}
Expand Down
91 changes: 39 additions & 52 deletions network/network_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -49,19 +49,20 @@ func testConfig() *Config {
EnableMdns: false,
ForcePrivateNetwork: true,
NetworkName: "test",
DefaultPort: 12345,
DefaultPort: FindFreePort(),
PeerStorePath: util.TempFilePath(),
StreamTimeout: 10 * time.Second,
}
}

func shouldReceiveEvent(t *testing.T, net *network, eventType EventType) Event {
t.Helper()

timeout := time.NewTimer(10 * time.Second)
timer := time.NewTimer(10 * time.Second)

for {
select {
case <-timeout.C:
case <-timer.C:
require.NoError(t, fmt.Errorf("shouldReceiveEvent Timeout, test: %v id:%s", t.Name(), net.SelfID().String()))

return nil
Expand All @@ -77,11 +78,11 @@ func shouldReceiveEvent(t *testing.T, net *network, eventType EventType) Event {
func shouldNotReceiveEvent(t *testing.T, net *network) {
t.Helper()

timeout := time.NewTimer(100 * time.Millisecond)
timer := time.NewTimer(100 * time.Millisecond)

for {
select {
case <-timeout.C:
case <-timer.C:
return

case <-net.EventChannel():
Expand Down Expand Up @@ -131,20 +132,17 @@ func TestStoppingNetwork(t *testing.T) {
func TestNetwork(t *testing.T) {
ts := testsuite.NewTestSuite(t)

bootstrapPort := ts.RandInt32(9999) + 10000
publicPort := ts.RandInt32(9999) + 10000

// Bootstrap node
confB := testConfig()
confB.ListenAddrStrings = []string{
fmt.Sprintf("/ip4/127.0.0.1/tcp/%v", bootstrapPort),
fmt.Sprintf("/ip4/127.0.0.1/tcp/%v", confB.DefaultPort),
}
fmt.Println("Starting Bootstrap node")
networkB := makeTestNetwork(t, confB, []lp2p.Option{
lp2p.ForceReachabilityPublic(),
})
bootstrapAddresses := []string{
fmt.Sprintf("/ip4/127.0.0.1/tcp/%v/p2p/%v", bootstrapPort, networkB.SelfID().String()),
fmt.Sprintf("/ip4/127.0.0.1/tcp/%v/p2p/%v", confB.DefaultPort, networkB.SelfID().String()),
}

// Public and relay node
Expand All @@ -153,14 +151,14 @@ func TestNetwork(t *testing.T) {
confP.EnableRelay = false
confP.EnableRelayService = true
confP.ListenAddrStrings = []string{
fmt.Sprintf("/ip4/127.0.0.1/tcp/%v", publicPort),
fmt.Sprintf("/ip4/127.0.0.1/tcp/%v", confP.DefaultPort),
}
fmt.Println("Starting Public node")
networkP := makeTestNetwork(t, confP, []lp2p.Option{
lp2p.ForceReachabilityPublic(),
})
publicAddrInfo, _ := lp2ppeer.AddrInfoFromString(
fmt.Sprintf("/ip4/127.0.0.1/tcp/%v/p2p/%s", publicPort, networkP.SelfID()))
fmt.Sprintf("/ip4/127.0.0.1/tcp/%v/p2p/%s", confP.DefaultPort, networkP.SelfID()))

// Private node M
confM := testConfig()
Expand Down Expand Up @@ -215,57 +213,57 @@ func TestNetwork(t *testing.T) {
t.Run("Supported Protocols", func(t *testing.T) {
fmt.Printf("Running %s\n", t.Name())

require.EventuallyWithT(t, func(_ *assert.CollectT) {
require.EventuallyWithT(t, func(c *assert.CollectT) {
protos := networkM.Protocols()
assert.Contains(t, protos, lp2pproto.ProtoIDv2Stop)
assert.NotContains(t, protos, lp2pproto.ProtoIDv2Hop)
assert.Contains(c, protos, lp2pproto.ProtoIDv2Stop)
assert.NotContains(c, protos, lp2pproto.ProtoIDv2Hop)
}, time.Second, 100*time.Millisecond)

require.EventuallyWithT(t, func(_ *assert.CollectT) {
require.EventuallyWithT(t, func(c *assert.CollectT) {
protos := networkN.Protocols()
assert.Contains(t, protos, lp2pproto.ProtoIDv2Stop)
assert.NotContains(t, protos, lp2pproto.ProtoIDv2Hop)
assert.Contains(c, protos, lp2pproto.ProtoIDv2Stop)
assert.NotContains(c, protos, lp2pproto.ProtoIDv2Hop)
}, time.Second, 100*time.Millisecond)

require.EventuallyWithT(t, func(_ *assert.CollectT) {
require.EventuallyWithT(t, func(c *assert.CollectT) {
protos := networkP.Protocols()
assert.NotContains(t, protos, lp2pproto.ProtoIDv2Stop)
assert.Contains(t, protos, lp2pproto.ProtoIDv2Hop)
assert.NotContains(c, protos, lp2pproto.ProtoIDv2Stop)
assert.Contains(c, protos, lp2pproto.ProtoIDv2Hop)
}, time.Second, 100*time.Millisecond)

require.EventuallyWithT(t, func(_ *assert.CollectT) {
require.EventuallyWithT(t, func(c *assert.CollectT) {
protos := networkX.Protocols()
assert.NotContains(t, protos, lp2pproto.ProtoIDv2Stop)
assert.NotContains(t, protos, lp2pproto.ProtoIDv2Hop)
assert.NotContains(c, protos, lp2pproto.ProtoIDv2Stop)
assert.NotContains(c, protos, lp2pproto.ProtoIDv2Hop)
}, time.Second, 100*time.Millisecond)
})

t.Run("Reachability", func(t *testing.T) {
fmt.Printf("Running %s\n", t.Name())

require.EventuallyWithT(t, func(_ *assert.CollectT) {
require.EventuallyWithT(t, func(c *assert.CollectT) {
reachability := networkB.ReachabilityStatus()
assert.Equal(t, "Public", reachability)
assert.Equal(c, "Public", reachability)
}, time.Second, 100*time.Millisecond)

require.EventuallyWithT(t, func(_ *assert.CollectT) {
require.EventuallyWithT(t, func(c *assert.CollectT) {
reachability := networkM.ReachabilityStatus()
assert.Equal(t, "Private", reachability)
assert.Equal(c, "Private", reachability)
}, time.Second, 100*time.Millisecond)

require.EventuallyWithT(t, func(_ *assert.CollectT) {
require.EventuallyWithT(t, func(c *assert.CollectT) {
reachability := networkN.ReachabilityStatus()
assert.Equal(t, "Private", reachability)
assert.Equal(c, "Private", reachability)
}, time.Second, 100*time.Millisecond)

require.EventuallyWithT(t, func(_ *assert.CollectT) {
require.EventuallyWithT(t, func(c *assert.CollectT) {
reachability := networkP.ReachabilityStatus()
assert.Equal(t, "Public", reachability)
assert.Equal(c, "Public", reachability)
}, time.Second, 100*time.Millisecond)

require.EventuallyWithT(t, func(_ *assert.CollectT) {
require.EventuallyWithT(t, func(c *assert.CollectT) {
reachability := networkP.ReachabilityStatus()
assert.Equal(t, "Public", reachability)
assert.Equal(c, "Public", reachability)
}, time.Second, 100*time.Millisecond)
})

Expand Down Expand Up @@ -421,23 +419,20 @@ func TestNetwork(t *testing.T) {
func TestConnections(t *testing.T) {
t.Parallel() // run the tests in parallel

ts := testsuite.NewTestSuite(t)

tests := []struct {
bootstrapAddr string
peerAddr string
}{
{"/ip4/127.0.0.1/tcp/%d", "/ip4/127.0.0.1/tcp/0"},
{"/ip4/127.0.0.1/udp/%d/quic-v1", "/ip4/127.0.0.1/udp/0/quic-v1"},
{"/ip6/::1/tcp/%d", "/ip6/::1/tcp/0"},
{"/ip4/127.0.0.1/udp/%d/quic-v1", "/ip4/127.0.0.1/udp/0/quic-v1"},
{"/ip6/::1/udp/%d/quic-v1", "/ip6/::1/udp/0/quic-v1"},
}

for i, test := range tests {
// Bootstrap node
confB := testConfig()
bootstrapPort := ts.RandInt32(9999) + 10000
bootstrapAddr := fmt.Sprintf(test.bootstrapAddr, bootstrapPort)
bootstrapAddr := fmt.Sprintf(test.bootstrapAddr, confB.DefaultPort)
confB.ListenAddrStrings = []string{bootstrapAddr}
fmt.Println("Starting Bootstrap node")
networkB := makeTestNetwork(t, confB, []lp2p.Option{
Expand All @@ -456,7 +451,7 @@ func TestConnections(t *testing.T) {
})

t.Run(fmt.Sprintf("Running test %d: %s <-> %s ... ",
i, test.bootstrapAddr, test.peerAddr), func(t *testing.T) {
i, bootstrapAddr, test.peerAddr), func(t *testing.T) {
t.Parallel() // run the tests in parallel

testConnection(t, networkP, networkB)
Expand All @@ -467,20 +462,12 @@ func TestConnections(t *testing.T) {
func testConnection(t *testing.T, networkP, networkB *network) {
t.Helper()

// Ensure that peers are connected to each other
for i := 0; i < 20; i++ {
if networkP.NumConnectedPeers() >= 1 &&
networkB.NumConnectedPeers() >= 1 {
break
}
time.Sleep(100 * time.Millisecond)
}

assert.Equal(t, 1, networkB.NumConnectedPeers())
assert.Equal(t, 1, networkP.NumConnectedPeers())
assert.EventuallyWithT(t, func(c *assert.CollectT) {
assert.GreaterOrEqual(c, networkP.NumConnectedPeers(), 1)
assert.GreaterOrEqual(c, networkB.NumConnectedPeers(), 1)
}, 5*time.Second, 100*time.Millisecond)

msg := []byte("test-msg")

networkP.SendTo(msg, networkB.SelfID())
e := shouldReceiveEvent(t, networkB, EventTypeStream).(*StreamMessage)
assert.Equal(t, networkP.SelfID(), e.From)
Expand Down
2 changes: 1 addition & 1 deletion network/notifee.go
Original file line number Diff line number Diff line change
Expand Up @@ -116,7 +116,7 @@ func (s *NotifeeService) Listen(_ lp2pnetwork.Network, ma multiaddr.Multiaddr) {
s.logger.Debug("notifee Listen event emitted", "addr", ma.String())
}

// ListenClose is called when your node stops listening on an address.
// ListenClose is called when the peer stops listening on an address.
func (s *NotifeeService) ListenClose(_ lp2pnetwork.Network, ma multiaddr.Multiaddr) {
// Handle listen close event if needed.
s.logger.Debug("notifee ListenClose event emitted", "addr", ma.String())
Expand Down
Loading

0 comments on commit ae57202

Please sign in to comment.