diff --git a/src/app/libp2p_helper/src/bitswap_storage.go b/src/app/libp2p_helper/src/bitswap_storage.go
index ff7e089eef22..5a54e36eb2fa 100644
--- a/src/app/libp2p_helper/src/bitswap_storage.go
+++ b/src/app/libp2p_helper/src/bitswap_storage.go
@@ -4,9 +4,9 @@ import (
 	"context"
 	"fmt"
 
+	"github.com/ipfs/boxo/blockstore"
 	blocks "github.com/ipfs/go-block-format"
 	"github.com/ipfs/go-cid"
-	"github.com/ipfs/boxo/blockstore"
 	"github.com/ledgerwatch/lmdb-go/lmdb"
 	"github.com/multiformats/go-multihash"
 	lmdbbs "github.com/o1-labs/go-bs-lmdb"
diff --git a/src/app/libp2p_helper/src/codanet.go b/src/app/libp2p_helper/src/codanet.go
index 431eb76c3291..df4c732d47da 100644
--- a/src/app/libp2p_helper/src/codanet.go
+++ b/src/app/libp2p_helper/src/codanet.go
@@ -10,12 +10,17 @@ import (
 	"sync"
 	"time"
 
-	"github.com/ipfs/boxo/bitswap"
+	"github.com/ipfs/boxo/bitswap"
 	bitnet "github.com/ipfs/boxo/bitswap/network"
 	dsb "github.com/ipfs/go-ds-badger"
 	logging "github.com/ipfs/go-log/v2"
 	p2p "github.com/libp2p/go-libp2p"
+	dht "github.com/libp2p/go-libp2p-kad-dht"
+	"github.com/libp2p/go-libp2p-kad-dht/dual"
+	pubsub "github.com/libp2p/go-libp2p-pubsub"
+	record "github.com/libp2p/go-libp2p-record"
+	p2pconfig "github.com/libp2p/go-libp2p/config"
 	"github.com/libp2p/go-libp2p/core/connmgr"
 	"github.com/libp2p/go-libp2p/core/control"
 	"github.com/libp2p/go-libp2p/core/crypto"
@@ -25,14 +30,9 @@ import (
 	"github.com/libp2p/go-libp2p/core/peer"
 	"github.com/libp2p/go-libp2p/core/protocol"
 	"github.com/libp2p/go-libp2p/core/routing"
-	dht "github.com/libp2p/go-libp2p-kad-dht"
-	"github.com/libp2p/go-libp2p-kad-dht/dual"
-	"github.com/libp2p/go-libp2p/p2p/host/peerstore/pstoreds"
-	pubsub "github.com/libp2p/go-libp2p-pubsub"
-	record "github.com/libp2p/go-libp2p-record"
-	p2pconfig "github.com/libp2p/go-libp2p/config"
 	mdns "github.com/libp2p/go-libp2p/p2p/discovery/mdns"
 	discovery "github.com/libp2p/go-libp2p/p2p/discovery/routing"
+	"github.com/libp2p/go-libp2p/p2p/host/peerstore/pstoreds"
 	libp2pyamux "github.com/libp2p/go-libp2p/p2p/muxer/yamux"
 	p2pconnmgr "github.com/libp2p/go-libp2p/p2p/net/connmgr"
 	"github.com/libp2p/go-libp2p/p2p/transport/tcp"
@@ -89,9 +89,11 @@ func isPrivateAddr(addr ma.Multiaddr) bool {
 }
 
 type CodaConnectionManager struct {
-	p2pManager   *p2pconnmgr.BasicConnMgr
-	OnConnect    func(network.Network, network.Conn)
-	OnDisconnect func(network.Network, network.Conn)
+	p2pManager        *p2pconnmgr.BasicConnMgr
+	onConnectMutex    sync.RWMutex
+	onConnect         func(network.Network, network.Conn)
+	onDisconnectMutex sync.RWMutex
+	onDisconnect      func(network.Network, network.Conn)
 	// protectedMirror is a map of protected peer ids/tags, mirroring the structure in
 	// BasicConnMgr which is not accessible from CodaConnectionManager
 	protectedMirror map[peer.ID]map[string]interface{}
@@ -99,16 +101,20 @@ type CodaConnectionManager struct {
 }
 
 func (cm *CodaConnectionManager) AddOnConnectHandler(f func(network.Network, network.Conn)) {
-	prevOnConnect := cm.OnConnect
-	cm.OnConnect = func(net network.Network, c network.Conn) {
+	cm.onConnectMutex.Lock()
+	defer cm.onConnectMutex.Unlock()
+	prevOnConnect := cm.onConnect
+	cm.onConnect = func(net network.Network, c network.Conn) {
 		prevOnConnect(net, c)
 		f(net, c)
 	}
 }
 
 func (cm *CodaConnectionManager) AddOnDisconnectHandler(f func(network.Network, network.Conn)) {
-	prevOnDisconnect := cm.OnDisconnect
-	cm.OnDisconnect = func(net network.Network, c network.Conn) {
+	cm.onDisconnectMutex.Lock()
+	defer cm.onDisconnectMutex.Unlock()
+	prevOnDisconnect := cm.onDisconnect
+	cm.onDisconnect = func(net network.Network, c network.Conn) {
 		prevOnDisconnect(net, c)
 		f(net, c)
 	}
 }
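
Editor's aside on the hunk above: `AddOnConnectHandler`/`AddOnDisconnectHandler` compose the previous callback with the new one and publish the composed function under a write lock, so registration and invocation never race. A minimal, self-contained sketch of the same pattern — `handlerChain` and its methods are illustrative names, not part of this PR:

```go
package main

import (
	"fmt"
	"sync"
)

// handlerChain composes callbacks the way AddOnConnectHandler does: each new
// handler wraps the previous one, and the composed function is swapped in
// under a write lock so concurrent readers never observe a torn update.
type handlerChain struct {
	mu sync.RWMutex
	fn func(msg string)
}

func (c *handlerChain) Add(f func(string)) {
	c.mu.Lock()
	defer c.mu.Unlock()
	prev := c.fn
	c.fn = func(msg string) {
		if prev != nil {
			prev(msg)
		}
		f(msg)
	}
}

// Snapshot returns the current composed handler under a read lock, so the
// actual invocation happens without holding the lock (as onConnectHandler()
// does below).
func (c *handlerChain) Snapshot() func(string) {
	c.mu.RLock()
	defer c.mu.RUnlock()
	return c.fn
}

func main() {
	var c handlerChain
	c.Add(func(m string) { fmt.Println("first:", m) })
	c.Add(func(m string) { fmt.Println("second:", m) })
	c.Snapshot()("peer connected") // runs both handlers, in registration order
}
```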
@@ -122,8 +128,8 @@ func newCodaConnectionManager(minConnections, maxConnections int, grace time.Dur
 	}
 	return &CodaConnectionManager{
 		p2pManager:      connmgr,
-		OnConnect:       noop,
-		OnDisconnect:    noop,
+		onConnect:       noop,
+		onDisconnect:    noop,
 		protectedMirror: make(map[peer.ID]map[string]interface{}),
 	}, nil
 }
@@ -191,14 +197,27 @@ func (cm *CodaConnectionManager) Listen(net network.Network, addr ma.Multiaddr)
 func (cm *CodaConnectionManager) ListenClose(net network.Network, addr ma.Multiaddr) {
 	cm.p2pManager.Notifee().ListenClose(net, addr)
 }
+
+func (cm *CodaConnectionManager) onConnectHandler() func(net network.Network, c network.Conn) {
+	cm.onConnectMutex.RLock()
+	defer cm.onConnectMutex.RUnlock()
+	return cm.onConnect
+}
+
 func (cm *CodaConnectionManager) Connected(net network.Network, c network.Conn) {
 	logger.Debugf("%s connected to %s", c.LocalPeer(), c.RemotePeer())
-	cm.OnConnect(net, c)
+	cm.onConnectHandler()(net, c)
 	cm.p2pManager.Notifee().Connected(net, c)
 }
 
+func (cm *CodaConnectionManager) onDisconnectHandler() func(net network.Network, c network.Conn) {
+	cm.onDisconnectMutex.RLock()
+	defer cm.onDisconnectMutex.RUnlock()
+	return cm.onDisconnect
+}
+
 func (cm *CodaConnectionManager) Disconnected(net network.Network, c network.Conn) {
-	cm.OnDisconnect(net, c)
+	cm.onDisconnectHandler()(net, c)
 	cm.p2pManager.Notifee().Disconnected(net, c)
 }
@@ -224,7 +243,6 @@ type Helper struct {
 	ConnectionManager *CodaConnectionManager
 	BandwidthCounter  *metrics.BandwidthCounter
 	MsgStats          *MessageStats
-	Seeds             []peer.AddrInfo
 	NodeStatus        []byte
 	HeartbeatPeer     func(peer.ID)
 }
@@ -273,8 +291,28 @@ func (ms *MessageStats) GetStats() *safeStats {
 	}
 }
 
-func (h *Helper) ResetGatingConfigTrustedAddrFilters() {
-	h.gatingState.TrustedAddrFilters = ma.NewFilters()
+func (h *Helper) SetBannedPeers(newP map[peer.ID]struct{}) {
+	h.gatingState.bannedPeersMutex.Lock()
+	defer h.gatingState.bannedPeersMutex.Unlock()
+	h.gatingState.bannedPeers = newP
+}
+
+func (h *Helper) SetTrustedPeers(newP map[peer.ID]struct{}) {
+	h.gatingState.trustedPeersMutex.Lock()
+	defer h.gatingState.trustedPeersMutex.Unlock()
+	h.gatingState.trustedPeers = newP
+}
+
+func (h *Helper) SetTrustedAddrFilters(newF *ma.Filters) {
+	h.gatingState.trustedAddrFiltersMutex.Lock()
+	defer h.gatingState.trustedAddrFiltersMutex.Unlock()
+	h.gatingState.trustedAddrFilters = newF
+}
+
+func (h *Helper) SetBannedAddrFilters(newF *ma.Filters) {
+	h.gatingState.bannedAddrFiltersMutex.Lock()
+	defer h.gatingState.bannedAddrFiltersMutex.Unlock()
+	h.gatingState.bannedAddrFilters = newF
 }
 
 // this type implements the ConnectionGating interface
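
Note on the setters above: each writer replaces the guarded value wholesale (a fresh map or filter set) under its own lock, so readers under `RLock` either see the old value or the new one, never a half-mutated structure. A hypothetical stand-in (`guardedSet` is not part of this PR) showing the shape of the pattern:

```go
package main

import (
	"fmt"
	"sync"
)

// guardedSet mimics how the gating setters work: writers swap in a whole new
// map under Lock; readers consult it under RLock. Because the map is replaced
// rather than mutated in place, readers never observe a partial update.
type guardedSet struct {
	mu  sync.RWMutex
	set map[string]struct{}
}

func (g *guardedSet) Replace(newSet map[string]struct{}) {
	g.mu.Lock()
	defer g.mu.Unlock()
	g.set = newSet
}

func (g *guardedSet) Contains(k string) bool {
	g.mu.RLock()
	defer g.mu.RUnlock()
	_, ok := g.set[k]
	return ok
}

func main() {
	g := &guardedSet{set: map[string]struct{}{}}
	g.Replace(map[string]struct{}{"peerA": {}})
	fmt.Println(g.Contains("peerA")) // true
	fmt.Println(g.Contains("peerB")) // false
}
```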
@@ -283,10 +321,14 @@ func (h *Helper) ResetGatingConfigTrustedAddrFilters() {
 type CodaGatingState struct {
 	logger                  logging.EventLogger
 	KnownPrivateAddrFilters *ma.Filters
-	BannedAddrFilters       *ma.Filters
-	TrustedAddrFilters      *ma.Filters
-	BannedPeers             map[peer.ID]struct{}
-	TrustedPeers            map[peer.ID]struct{}
+	bannedAddrFiltersMutex  sync.RWMutex
+	bannedAddrFilters       *ma.Filters
+	trustedAddrFiltersMutex sync.RWMutex
+	trustedAddrFilters      *ma.Filters
+	bannedPeersMutex        sync.RWMutex
+	bannedPeers             map[peer.ID]struct{}
+	trustedPeersMutex       sync.RWMutex
+	trustedPeers            map[peer.ID]struct{}
 }
 
 type CodaGatingConfig struct {
@@ -322,11 +364,11 @@ func NewCodaGatingState(config *CodaGatingConfig, knownPrivateAddrFilters *ma.Fi
 
 	return &CodaGatingState{
 		logger:                  logger,
-		BannedAddrFilters:       bannedAddrFilters,
-		TrustedAddrFilters:      trustedAddrFilters,
+		bannedAddrFilters:       bannedAddrFilters,
+		trustedAddrFilters:      trustedAddrFilters,
 		KnownPrivateAddrFilters: knownPrivateAddrFilters,
-		BannedPeers:             bannedPeers,
-		TrustedPeers:            trustedPeers,
+		bannedPeers:             bannedPeers,
+		trustedPeers:            trustedPeers,
 	}
 }
@@ -335,10 +377,10 @@ func (h *Helper) GatingState() *CodaGatingState {
 }
 
 func (h *Helper) SetGatingState(gs *CodaGatingConfig) {
-	h.gatingState.TrustedPeers = gs.TrustedPeers
-	h.gatingState.BannedPeers = gs.BannedPeers
-	h.gatingState.TrustedAddrFilters = gs.TrustedAddrFilters
-	h.gatingState.BannedAddrFilters = gs.BannedAddrFilters
+	h.SetTrustedPeers(gs.TrustedPeers)
+	h.SetBannedPeers(gs.BannedPeers)
+	h.SetTrustedAddrFilters(gs.TrustedAddrFilters)
+	h.SetBannedAddrFilters(gs.BannedAddrFilters)
 	for _, c := range h.Host.Network().Conns() {
 		pid := c.RemotePeer()
 		maddr := c.RemoteMultiaddr()
@@ -352,6 +394,12 @@ func (h *Helper) SetGatingState(gs *CodaGatingConfig) {
 	}
 }
 
+func (gs *CodaGatingState) TrustPeer(p peer.ID) {
+	gs.trustedPeersMutex.Lock()
+	defer gs.trustedPeersMutex.Unlock()
+	gs.trustedPeers[p] = struct{}{}
+}
+
 func (gs *CodaGatingState) MarkPrivateAddrAsKnown(addr ma.Multiaddr) {
 	if isPrivateAddr(addr) && gs.KnownPrivateAddrFilters.AddrBlocked(addr) {
 		gs.logger.Infof("marking private addr %v as known", addr)
@@ -397,7 +445,9 @@ func (c connectionAllowance) isDeny() bool {
 }
 
 func (gs *CodaGatingState) checkPeerTrusted(p peer.ID) connectionAllowance {
-	_, isTrusted := gs.TrustedPeers[p]
+	gs.trustedPeersMutex.RLock()
+	defer gs.trustedPeersMutex.RUnlock()
+	_, isTrusted := gs.trustedPeers[p]
 	if isTrusted {
 		return Accept
 	}
@@ -405,7 +455,9 @@ func (gs *CodaGatingState) checkPeerTrusted(p peer.ID) connectionAllowance {
 }
 
 func (gs *CodaGatingState) checkPeerBanned(p peer.ID) connectionAllowance {
-	_, isBanned := gs.BannedPeers[p]
+	gs.bannedPeersMutex.RLock()
+	defer gs.bannedPeersMutex.RUnlock()
+	_, isBanned := gs.bannedPeers[p]
 	if isBanned {
 		return DenyBannedPeer
 	}
@@ -440,14 +492,18 @@ func (gs *CodaGatingState) checkAllowedPeer(p peer.ID) connectionAllowance {
 }
 
 func (gs *CodaGatingState) checkAddrTrusted(addr ma.Multiaddr) connectionAllowance {
-	if !gs.TrustedAddrFilters.AddrBlocked(addr) {
+	gs.trustedAddrFiltersMutex.RLock()
+	defer gs.trustedAddrFiltersMutex.RUnlock()
+	if !gs.trustedAddrFilters.AddrBlocked(addr) {
 		return Accept
 	}
 	return Undecided
 }
 
 func (gs *CodaGatingState) checkAddrBanned(addr ma.Multiaddr) connectionAllowance {
-	if gs.BannedAddrFilters.AddrBlocked(addr) {
+	gs.bannedAddrFiltersMutex.RLock()
+	defer gs.bannedAddrFiltersMutex.RUnlock()
+	if gs.bannedAddrFilters.AddrBlocked(addr) {
 		return DenyBannedAddress
 	}
 	return Undecided
@@ -721,7 +777,6 @@ func MakeHelper(ctx context.Context, listenOn []ma.Multiaddr, externalAddr ma.Mu
 		ConnectionManager: connManager,
 		BandwidthCounter:  bandwidthCounter,
 		MsgStats:          &MessageStats{min: math.MaxUint64},
-		Seeds:             seeds,
 		HeartbeatPeer: func(p peer.ID) {
 			lanPatcher.Heartbeat(p)
 			wanPatcher.Heartbeat(p)
diff --git a/src/app/libp2p_helper/src/codanet_test.go b/src/app/libp2p_helper/src/codanet_test.go
index 572818dfcd8b..669467e559b5 100644
--- a/src/app/libp2p_helper/src/codanet_test.go
+++ b/src/app/libp2p_helper/src/codanet_test.go
@@ -37,7 +37,7 @@ func TestTrustedPrivateConnectionGating(t *testing.T) {
 	allowed := gs.InterceptAddrDial(testInfo.ID, testMa)
 	require.False(t, allowed)
 
-	gs.TrustedPeers[testInfo.ID] = struct{}{}
+	gs.TrustPeer(testInfo.ID)
 	allowed = gs.InterceptAddrDial(testInfo.ID, testMa)
 	require.True(t, allowed)
 }
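
The `checkPeerTrusted`/`checkAddrBanned` helpers above each return `Accept`, a deny variant, or `Undecided`, and the gating interceptors combine them so the first decisive answer wins. A heavily simplified sketch of that combination — the names, ordering, and default policy here are illustrative only, not the PR's exact gating logic:

```go
package main

import "fmt"

type connectionAllowance int

const (
	Undecided connectionAllowance = iota
	Accept
	Deny
)

// resolve chains gating checks: the first check returning a decision wins,
// and Undecided falls through to the next check. The final default is a
// placeholder; the real code decides per call site.
func resolve(checks ...func() connectionAllowance) connectionAllowance {
	for _, check := range checks {
		if d := check(); d != Undecided {
			return d
		}
	}
	return Deny
}

func main() {
	trusted := func() connectionAllowance { return Undecided } // peer not trusted
	banned := func() connectionAllowance { return Deny }       // peer is banned
	fmt.Println(resolve(trusted, banned) == Deny)              // true
}
```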
diff --git a/src/app/libp2p_helper/src/libp2p_helper/app.go b/src/app/libp2p_helper/src/libp2p_helper/app.go
index ad70a586435d..2c49af18ea3e 100644
--- a/src/app/libp2p_helper/src/libp2p_helper/app.go
+++ b/src/app/libp2p_helper/src/libp2p_helper/app.go
@@ -8,35 +8,33 @@ import (
 	"math"
 	"os"
 	"strconv"
-	"sync"
 	"time"
 
 	ipc "libp2p_ipc"
 
 	capnp "capnproto.org/go/capnp/v3"
 	"github.com/go-errors/errors"
+	pubsub "github.com/libp2p/go-libp2p-pubsub"
 	net "github.com/libp2p/go-libp2p/core/network"
 	peer "github.com/libp2p/go-libp2p/core/peer"
-	pubsub "github.com/libp2p/go-libp2p-pubsub"
 	mdns "github.com/libp2p/go-libp2p/p2p/discovery/mdns"
 	"github.com/multiformats/go-multiaddr"
 	"github.com/prometheus/client_golang/prometheus"
 )
 
 func newApp() *app {
-	outChan := make(chan *capnp.Message, 1<<12) // 4kb
+	outChan := make(chan *capnp.Message, 1<<12) // buffer of up to 4096 queued messages
 	ctx := context.Background()
 	return &app{
 		P2p:            nil,
 		Ctx:            ctx,
-		Subs:           make(map[uint64]subscription),
-		Topics:         make(map[string]*pubsub.Topic),
-		ValidatorMutex: &sync.Mutex{},
-		Validators:     make(map[uint64]*validationStatus),
-		Streams:        make(map[uint64]net.Stream),
+		_subs:       make(map[uint64]subscription),
+		_topics:     make(map[string]*pubsub.Topic),
+		_validators: make(map[uint64]*validationStatus),
+		_streams:    make(map[uint64]*stream),
 		OutChan:        outChan,
 		Out:            bufio.NewWriter(os.Stdout),
-		AddedPeers:     []peer.AddrInfo{},
+		_addedPeers:    []peer.AddrInfo{},
 		MetricsRefreshTime:       time.Minute,
 		metricsCollectionStarted: false,
 		metricsServer:            nil,
@@ -48,11 +46,11 @@ func (app *app) SetConnectionHandlers() {
 	app.setConnectionHandlersOnce.Do(func() {
 		app.P2p.ConnectionManager.AddOnConnectHandler(func(net net.Network, c net.Conn) {
 			app.updateConnectionMetrics()
-			app.writeMsg(mkPeerConnectedUpcall(peer.Encode(c.RemotePeer())))
+			app.writeMsg(mkPeerConnectedUpcall(c.RemotePeer().String()))
 		})
 		app.P2p.ConnectionManager.AddOnDisconnectHandler(func(net net.Network, c net.Conn) {
 			app.updateConnectionMetrics()
-			app.writeMsg(mkPeerDisconnectedUpcall(peer.Encode(c.RemotePeer())))
+			app.writeMsg(mkPeerDisconnectedUpcall(c.RemotePeer().String()))
 		})
 	})
 }
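
The upcall change above swaps `peer.Encode` (deprecated in recent go-libp2p releases) for the peer ID's `String` method; both yield the same textual peer ID. A small round-trip sketch, assuming current go-libp2p core APIs:

```go
package main

import (
	"crypto/rand"
	"fmt"

	"github.com/libp2p/go-libp2p/core/crypto"
	"github.com/libp2p/go-libp2p/core/peer"
)

func main() {
	// derive a peer ID from a fresh Ed25519 key, then round-trip its string form
	_, pub, err := crypto.GenerateEd25519Key(rand.Reader)
	if err != nil {
		panic(err)
	}
	id, err := peer.IDFromPublicKey(pub)
	if err != nil {
		panic(err)
	}
	s := id.String() // replaces the deprecated peer.Encode(id)
	back, err := peer.Decode(s)
	if err != nil {
		panic(err)
	}
	fmt.Println(back == id) // true
}
```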
@@ -64,6 +62,125 @@ func (app *app) NextId() uint64 {
 	return app.counter
 }
 
+func (app *app) AddPeers(infos ...peer.AddrInfo) {
+	app.addedPeersMutex.Lock()
+	defer app.addedPeersMutex.Unlock()
+	app._addedPeers = append(app._addedPeers, infos...)
+}
+
+// GetAddedPeers returns the list of added peers
+//
+// Elements of the returned slice must never be modified!
+func (app *app) GetAddedPeers() []peer.AddrInfo {
+	app.addedPeersMutex.RLock()
+	defer app.addedPeersMutex.RUnlock()
+	return app._addedPeers
+}
+
+func (app *app) ResetAddedPeers() {
+	app.addedPeersMutex.Lock()
+	defer app.addedPeersMutex.Unlock()
+	app._addedPeers = nil
+}
+
+func (app *app) AddStream(stream_ net.Stream) uint64 {
+	streamIdx := app.NextId()
+	app.streamsMutex.Lock()
+	defer app.streamsMutex.Unlock()
+	app._streams[streamIdx] = &stream{stream: stream_}
+	return streamIdx
+}
+
+func (app *app) RemoveStream(streamId uint64) (*stream, bool) {
+	app.streamsMutex.Lock()
+	defer app.streamsMutex.Unlock()
+	stream, ok := app._streams[streamId]
+	delete(app._streams, streamId)
+	return stream, ok
+}
+
+func (app *app) getStream(streamId uint64) (*stream, bool) {
+	app.streamsMutex.RLock()
+	defer app.streamsMutex.RUnlock()
+	s, has := app._streams[streamId]
+	return s, has
+}
+
+func (app *app) WriteStream(streamId uint64, data []byte) error {
+	if stream, ok := app.getStream(streamId); ok {
+		stream.mutex.Lock()
+		defer stream.mutex.Unlock()
+
+		if n, err := stream.stream.Write(data); err != nil {
+			// TODO check that it's correct to error out, not repeat writing
+			_, has := app.RemoveStream(streamId)
+			if has {
+				// Only close the stream if it was still tracked here; if it had
+				// already been removed, it is closed or soon to be closed by
+				// another goroutine
+				close_err := stream.stream.Close()
+				if close_err != nil {
+					app.P2p.Logger.Debugf("failed to close stream %d after encountering write failure (%s): %s", streamId, err.Error(), close_err.Error())
+				}
+			}
+			return wrapError(badp2p(err), fmt.Sprintf("only wrote %d out of %d bytes", n, len(data)))
+		}
+		return nil
+	}
+	return badRPC(errors.New("unknown stream_idx"))
+}
+
+func (app *app) AddValidator() (uint64, chan pubsub.ValidationResult) {
+	seqno := app.NextId()
+	ch := make(chan pubsub.ValidationResult)
+	app.validatorMutex.Lock()
+	defer app.validatorMutex.Unlock()
+	app._validators[seqno] = new(validationStatus)
+	app._validators[seqno].Completion = ch
+	return seqno, ch
+}
+
+func (app *app) TimeoutValidator(seqno uint64) {
+	now := time.Now()
+	app.validatorMutex.Lock()
+	defer app.validatorMutex.Unlock()
+	app._validators[seqno].TimedOutAt = &now
+}
+
+func (app *app) RemoveValidator(seqno uint64) (*validationStatus, bool) {
+	app.validatorMutex.Lock()
+	defer app.validatorMutex.Unlock()
+	st, ok := app._validators[seqno]
+	delete(app._validators, seqno)
+	return st, ok
+}
+
+func (app *app) AddTopic(topicName string, topic *pubsub.Topic) {
+	app.topicsMutex.Lock()
+	defer app.topicsMutex.Unlock()
+	app._topics[topicName] = topic
+}
+
+func (app *app) GetTopic(topicName string) (*pubsub.Topic, bool) {
+	app.topicsMutex.RLock()
+	defer app.topicsMutex.RUnlock()
+	topic, has := app._topics[topicName]
+	return topic, has
+}
+
+func (app *app) AddSubscription(subId uint64, sub subscription) {
+	app.subsMutex.Lock()
+	defer app.subsMutex.Unlock()
+	app._subs[subId] = sub
+}
+
+func (app *app) RemoveSubscription(subId uint64) (subscription, bool) {
+	app.subsMutex.Lock()
+	defer app.subsMutex.Unlock()
+	sub, ok := app._subs[subId]
+	delete(app._subs, subId)
+	return sub, ok
+}
+
 func parseMultiaddrWithID(ma multiaddr.Multiaddr, id peer.ID) (*codaPeerInfo, error) {
 	ipComponent, tcpMaddr := multiaddr.SplitFirst(ma)
 	if !(ipComponent.Protocol().Code == multiaddr.P_IP4 || ipComponent.Protocol().Code == multiaddr.P_IP6) {
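
`WriteStream` above replaces the old global `StreamsMutex`-around-`Write` scheme: the registry lock is held only for the lookup, while the per-stream mutex serializes the potentially slow write. A self-contained sketch of that locking shape (`lockedWriter`/`registry` are illustrative names, not the PR's types):

```go
package main

import (
	"bytes"
	"fmt"
	"sync"
)

// lockedWriter mirrors the PR's stream wrapper: each writer carries its own
// mutex, so writes to different writers proceed in parallel while writes to
// the same writer are serialized.
type lockedWriter struct {
	mu  sync.Mutex
	buf bytes.Buffer
}

// registry mirrors the _streams map: an RWMutex guards the map itself, held
// only long enough to look a writer up.
type registry struct {
	mu      sync.RWMutex
	writers map[uint64]*lockedWriter
}

func (r *registry) get(id uint64) (*lockedWriter, bool) {
	r.mu.RLock()
	defer r.mu.RUnlock()
	w, ok := r.writers[id]
	return w, ok
}

func (r *registry) Write(id uint64, data []byte) error {
	w, ok := r.get(id)
	if !ok {
		return fmt.Errorf("unknown writer %d", id)
	}
	w.mu.Lock() // registry lock already released; only this writer is blocked
	defer w.mu.Unlock()
	_, err := w.buf.Write(data)
	return err
}

func main() {
	r := &registry{writers: map[uint64]*lockedWriter{1: {}}}
	fmt.Println(r.Write(1, []byte("hello"))) // <nil>
	fmt.Println(r.Write(2, []byte("world"))) // unknown writer 2
}
```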
@@ -96,6 +213,7 @@ func addrInfoOfString(maddr string) (*peer.AddrInfo, error) {
 	return info, nil
 }
 
+// Writes a message back to the OCaml node
 func (app *app) writeMsg(msg *capnp.Message) {
 	if app.NoUpcalls {
 		return
@@ -190,13 +308,13 @@ func (app *app) checkPeerCount() {
 
 	err = prometheus.Register(peerCount)
 	if err != nil {
-		app.P2p.Logger.Debugf("couldn't register peer_count; perhaps we've already done so", err.Error())
+		app.P2p.Logger.Debugf("couldn't register peer_count; perhaps we've already done so: %s", err)
 		return
 	}
 
 	err = prometheus.Register(connectedPeerCount)
 	if err != nil {
-		app.P2p.Logger.Debugf("couldn't register connected_peer_count; perhaps we've already done so", err.Error())
+		app.P2p.Logger.Debugf("couldn't register connected_peer_count; perhaps we've already done so: %s", err)
 		return
 	}
diff --git a/src/app/libp2p_helper/src/libp2p_helper/bandwidth_msg.go b/src/app/libp2p_helper/src/libp2p_helper/bandwidth_msg.go
index 7d8209ce68ee..cee135d24e64 100644
--- a/src/app/libp2p_helper/src/libp2p_helper/bandwidth_msg.go
+++ b/src/app/libp2p_helper/src/libp2p_helper/bandwidth_msg.go
@@ -17,7 +17,7 @@ func fromBandwidthInfoReq(req ipcRpcRequest) (rpcRequest, error) {
 	return BandwidthInfoReq(i), err
 }
 
-func (msg BandwidthInfoReq) handle(app *app, seqno uint64) *capnp.Message {
+func (msg BandwidthInfoReq) handle(app *app, seqno uint64) (*capnp.Message, func()) {
 	if app.P2p == nil {
 		return mkRpcRespError(seqno, needsConfigure())
 	}
diff --git a/src/app/libp2p_helper/src/libp2p_helper/bitswap.go b/src/app/libp2p_helper/src/libp2p_helper/bitswap.go
index 69b478f82034..f12063ed48b5 100644
--- a/src/app/libp2p_helper/src/libp2p_helper/bitswap.go
+++ b/src/app/libp2p_helper/src/libp2p_helper/bitswap.go
@@ -8,7 +8,7 @@ import (
 	"time"
 
 	"capnproto.org/go/capnp/v3"
-	"github.com/ipfs/boxo/bitswap"
+	"github.com/ipfs/boxo/bitswap"
 	blocks "github.com/ipfs/go-block-format"
 	"github.com/ipfs/go-cid"
 	exchange "github.com/ipfs/go-ipfs-exchange-interface"
diff --git a/src/app/libp2p_helper/src/libp2p_helper/bitswap_msg.go b/src/app/libp2p_helper/src/libp2p_helper/bitswap_msg.go
index 741f5a4b4c17..ab6f18ec1401 100644
--- a/src/app/libp2p_helper/src/libp2p_helper/bitswap_msg.go
+++ b/src/app/libp2p_helper/src/libp2p_helper/bitswap_msg.go
@@ -97,7 +97,7 @@ func fromTestDecodeBitswapBlocksReq(req ipcRpcRequest) (rpcRequest, error) {
 	return TestDecodeBitswapBlocksReq(i), err
 }
 
-func (m TestDecodeBitswapBlocksReq) handle(app *app, seqno uint64) *capnp.Message {
+func (m TestDecodeBitswapBlocksReq) handle(app *app, seqno uint64) (*capnp.Message, func()) {
 	blocks, err := TestDecodeBitswapBlocksReqT(m).Blocks()
 	if err != nil {
 		return mkRpcRespError(seqno, badRPC(err))
@@ -156,7 +156,7 @@ func fromTestEncodeBitswapBlocksReq(req ipcRpcRequest) (rpcRequest, error) {
 	return TestEncodeBitswapBlocksReq(i), err
 }
 
-func (m TestEncodeBitswapBlocksReq) handle(app *app, seqno uint64) *capnp.Message {
+func (m TestEncodeBitswapBlocksReq) handle(app *app, seqno uint64) (*capnp.Message, func()) {
 	mr := TestEncodeBitswapBlocksReqT(m)
 	data, err := mr.Data()
peerstore "github.com/libp2p/go-libp2p/core/peerstore" - pubsub "github.com/libp2p/go-libp2p-pubsub" - pb "github.com/libp2p/go-libp2p-pubsub/pb" discovery "github.com/libp2p/go-libp2p/p2p/discovery/routing" "github.com/multiformats/go-multiaddr" "golang.org/x/crypto/blake2b" @@ -31,12 +31,12 @@ func fromBeginAdvertisingReq(req ipcRpcRequest) (rpcRequest, error) { i, err := req.BeginAdvertising() return BeginAdvertisingReq(i), err } -func (msg BeginAdvertisingReq) handle(app *app, seqno uint64) *capnp.Message { +func (msg BeginAdvertisingReq) handle(app *app, seqno uint64) (*capnp.Message, func()) { if app.P2p == nil { return mkRpcRespError(seqno, needsConfigure()) } app.SetConnectionHandlers() - for _, info := range app.AddedPeers { + for _, info := range app.GetAddedPeers() { app.P2p.Logger.Debug("Trying to connect to: ", info) err := app.P2p.Host.Connect(app.Ctx, info) if err != nil { @@ -293,7 +293,7 @@ func fromConfigureReq(req ipcRpcRequest) (rpcRequest, error) { i, err := req.Configure() return ConfigureReq(i), err } -func (msg ConfigureReq) handle(app *app, seqno uint64) *capnp.Message { +func (msg ConfigureReq) handle(app *app, seqno uint64) (*capnp.Message, func()) { m, err := ConfigureReqT(msg).Config() if err != nil { return mkRpcRespError(seqno, badRPC(err)) @@ -334,7 +334,7 @@ func (msg ConfigureReq) handle(app *app, seqno uint64) *capnp.Message { return mkRpcRespError(seqno, badRPC(err)) } - app.AddedPeers = append(app.AddedPeers, seeds...) + app.AddPeers(seeds...) directPeersMaList, err := m.DirectPeers() if err != nil { @@ -372,12 +372,12 @@ func (msg ConfigureReq) handle(app *app, seqno uint64) *capnp.Message { if err != nil { return mkRpcRespError(seqno, badRPC(err)) } - gatingConfig, err := readGatingConfig(gc, app.AddedPeers) + gatingConfig, err := readGatingConfig(gc, app.GetAddedPeers()) if err != nil { return mkRpcRespError(seqno, badRPC(err)) } if gc.CleanAddedPeers() { - app.AddedPeers = nil + app.ResetAddedPeers() } stateDir, err := m.Statedir() @@ -487,7 +487,7 @@ func fromGetListeningAddrsReq(req ipcRpcRequest) (rpcRequest, error) { i, err := req.GetListeningAddrs() return GetListeningAddrsReq(i), err } -func (msg GetListeningAddrsReq) handle(app *app, seqno uint64) *capnp.Message { +func (msg GetListeningAddrsReq) handle(app *app, seqno uint64) (*capnp.Message, func()) { if app.P2p == nil { return mkRpcRespError(seqno, needsConfigure()) } @@ -508,7 +508,7 @@ func fromGenerateKeypairReq(req ipcRpcRequest) (rpcRequest, error) { i, err := req.GenerateKeypair() return GenerateKeypairReq(i), err } -func (msg GenerateKeypairReq) handle(app *app, seqno uint64) *capnp.Message { +func (msg GenerateKeypairReq) handle(app *app, seqno uint64) (*capnp.Message, func()) { privk, pubk, err := crypto.GenerateEd25519Key(cryptorand.Reader) if err != nil { return mkRpcRespError(seqno, badp2p(err)) @@ -548,7 +548,7 @@ func fromListenReq(req ipcRpcRequest) (rpcRequest, error) { i, err := req.Listen() return ListenReq(i), err } -func (m ListenReq) handle(app *app, seqno uint64) *capnp.Message { +func (m ListenReq) handle(app *app, seqno uint64) (*capnp.Message, func()) { if app.P2p == nil { return mkRpcRespError(seqno, needsConfigure()) } @@ -586,20 +586,20 @@ func fromSetGatingConfigReq(req ipcRpcRequest) (rpcRequest, error) { i, err := req.SetGatingConfig() return SetGatingConfigReq(i), err } -func (m SetGatingConfigReq) handle(app *app, seqno uint64) *capnp.Message { +func (m SetGatingConfigReq) handle(app *app, seqno uint64) (*capnp.Message, func()) { if app.P2p == nil { return 
@@ -586,20 +586,20 @@ func fromSetGatingConfigReq(req ipcRpcRequest) (rpcRequest, error) {
 	i, err := req.SetGatingConfig()
 	return SetGatingConfigReq(i), err
 }
-func (m SetGatingConfigReq) handle(app *app, seqno uint64) *capnp.Message {
+func (m SetGatingConfigReq) handle(app *app, seqno uint64) (*capnp.Message, func()) {
 	if app.P2p == nil {
 		return mkRpcRespError(seqno, needsConfigure())
 	}
 	var gatingConfig *codanet.CodaGatingConfig
 	gc, err := SetGatingConfigReqT(m).GatingConfig()
 	if err == nil {
-		gatingConfig, err = readGatingConfig(gc, app.AddedPeers)
+		gatingConfig, err = readGatingConfig(gc, app.GetAddedPeers())
 	}
 	if err != nil {
 		return mkRpcRespError(seqno, badRPC(err))
 	}
 	if gc.CleanAddedPeers() {
-		app.AddedPeers = nil
+		app.ResetAddedPeers()
 	}
 	app.P2p.SetGatingState(gatingConfig)
 
@@ -616,7 +616,7 @@ func fromSetNodeStatusReq(req ipcRpcRequest) (rpcRequest, error) {
 	i, err := req.SetNodeStatus()
 	return SetNodeStatusReq(i), err
 }
-func (m SetNodeStatusReq) handle(app *app, seqno uint64) *capnp.Message {
+func (m SetNodeStatusReq) handle(app *app, seqno uint64) (*capnp.Message, func()) {
 	status, err := SetNodeStatusReqT(m).Status()
 	if err != nil {
 		return mkRpcRespError(seqno, badRPC(err))
diff --git a/src/app/libp2p_helper/src/libp2p_helper/config_msg_test.go b/src/app/libp2p_helper/src/libp2p_helper/config_msg_test.go
index d5a6330492c1..2b8070932ae5 100644
--- a/src/app/libp2p_helper/src/libp2p_helper/config_msg_test.go
+++ b/src/app/libp2p_helper/src/libp2p_helper/config_msg_test.go
@@ -40,7 +40,7 @@ func TestDHTDiscovery_TwoNodes(t *testing.T) {
 	require.NoError(t, err)
 
 	appB, _ := newTestApp(t, appAInfos, true)
-	appB.AddedPeers = appAInfos
+	appB.AddPeers(appAInfos...)
 	appB.NoMDNS = true
 
 	// begin appB and appA's DHT advertising
@@ -190,7 +190,7 @@ func TestConfigure(t *testing.T) {
 	require.NoError(t, err)
 	gc.SetIsolate(false)
 
-	resMsg := ConfigureReq(m).handle(testApp, 239)
+	resMsg, _ := ConfigureReq(m).handle(testApp, 239)
 	require.NoError(t, err)
 	seqno, respSuccess := checkRpcResponseSuccess(t, resMsg, "configure")
 	require.Equal(t, seqno, uint64(239))
@@ -206,7 +206,7 @@ func TestGenerateKeypair(t *testing.T) {
 	require.NoError(t, err)
 
 	testApp, _ := newTestApp(t, nil, true)
-	resMsg := GenerateKeypairReq(m).handle(testApp, 7839)
+	resMsg, _ := GenerateKeypairReq(m).handle(testApp, 7839)
 	require.NoError(t, err)
 	seqno, respSuccess := checkRpcResponseSuccess(t, resMsg, "generateKeypair")
 	require.Equal(t, seqno, uint64(7839))
@@ -239,7 +239,7 @@ func TestGetListeningAddrs(t *testing.T) {
 	m, err := ipc.NewRootLibp2pHelperInterface_GetListeningAddrs_Request(seg)
 	require.NoError(t, err)
 	var mRpcSeqno uint64 = 1024
-	resMsg := GetListeningAddrsReq(m).handle(testApp, mRpcSeqno)
+	resMsg, _ := GetListeningAddrsReq(m).handle(testApp, mRpcSeqno)
 	seqno, respSuccess := checkRpcResponseSuccess(t, resMsg, "getListeningAddrs")
 	require.Equal(t, seqno, mRpcSeqno)
 	require.True(t, respSuccess.HasGetListeningAddrs())
@@ -265,7 +265,7 @@ func TestListen(t *testing.T) {
 	require.NoError(t, iface.SetRepresentation(addrStr))
 	require.NoError(t, err)
 
-	resMsg := ListenReq(m).handle(testApp, 1239)
+	resMsg, _ := ListenReq(m).handle(testApp, 1239)
 	require.NoError(t, err)
 	seqno, respSuccess := checkRpcResponseSuccess(t, resMsg, "listen")
 	require.Equal(t, seqno, uint64(1239))
@@ -316,7 +316,7 @@ func setGatingConfigImpl(t *testing.T, app *app, allowedIps, allowedIds, bannedI
 	gc.SetIsolate(false)
 
 	var mRpcSeqno uint64 = 2003
-	resMsg := SetGatingConfigReq(m).handle(app, mRpcSeqno)
+	resMsg, _ := SetGatingConfigReq(m).handle(app, mRpcSeqno)
 	seqno, respSuccess := checkRpcResponseSuccess(t, resMsg, "setGatingConfig")
 	require.Equal(t, seqno, mRpcSeqno)
 	require.True(t, respSuccess.HasSetGatingConfig())
@@ -369,7 +369,7 @@ func TestSetNodeStatus(t *testing.T) {
 	testStatus := []byte("test_node_status")
 	require.NoError(t, m.SetStatus(testStatus))
 
-	resMsg := SetNodeStatusReq(m).handle(testApp, 11239)
+	resMsg, _ := SetNodeStatusReq(m).handle(testApp, 11239)
 	require.NoError(t, err)
 	seqno, respSuccess := checkRpcResponseSuccess(t, resMsg, "setNodeStatus")
 	require.Equal(t, seqno, uint64(11239))
diff --git a/src/app/libp2p_helper/src/libp2p_helper/data.go b/src/app/libp2p_helper/src/libp2p_helper/data.go
index 7641a0790c0c..482242a699af 100644
--- a/src/app/libp2p_helper/src/libp2p_helper/data.go
+++ b/src/app/libp2p_helper/src/libp2p_helper/data.go
@@ -15,24 +15,45 @@ import (
 	"codanet"
 
 	capnp "capnproto.org/go/capnp/v3"
+	pubsub "github.com/libp2p/go-libp2p-pubsub"
 	net "github.com/libp2p/go-libp2p/core/network"
 	peer "github.com/libp2p/go-libp2p/core/peer"
-	pubsub "github.com/libp2p/go-libp2p-pubsub"
 )
 
+// Stream with mutex
+type stream struct {
+	mutex  sync.Mutex
+	stream net.Stream
+}
+
+func (s *stream) Reset() error {
+	s.mutex.Lock()
+	defer s.mutex.Unlock()
+	return s.stream.Reset()
+}
+
+func (s *stream) Close() error {
+	s.mutex.Lock()
+	defer s.mutex.Unlock()
+	return s.stream.Close()
+}
+
 type app struct {
 	P2p                      *codanet.Helper
 	Ctx                      context.Context
-	Subs                     map[uint64]subscription
-	Topics                   map[string]*pubsub.Topic
-	Validators               map[uint64]*validationStatus
-	ValidatorMutex           *sync.Mutex
-	Streams                  map[uint64]net.Stream
-	StreamsMutex             sync.Mutex
+	_subs                    map[uint64]subscription
+	subsMutex                sync.Mutex
+	_topics                  map[string]*pubsub.Topic
+	topicsMutex              sync.RWMutex
+	_validators              map[uint64]*validationStatus
+	validatorMutex           sync.Mutex
+	_streams                 map[uint64]*stream
+	streamsMutex             sync.RWMutex
 	Out                      *bufio.Writer
 	OutChan                  chan *capnp.Message
 	Bootstrapper             io.Closer
-	AddedPeers               []peer.AddrInfo
+	addedPeersMutex          sync.RWMutex
+	_addedPeers              []peer.AddrInfo
 	UnsafeNoTrustIP          bool
 	MetricsRefreshTime       time.Duration
 	metricsCollectionStarted bool
@@ -54,8 +75,6 @@ type app struct {
 
 type subscription struct {
 	Sub    *pubsub.Subscription
-	Idx    uint64
-	Ctx    context.Context
 	Cancel context.CancelFunc
 }
diff --git a/src/app/libp2p_helper/src/libp2p_helper/incoming_msg.go b/src/app/libp2p_helper/src/libp2p_helper/incoming_msg.go
index 38b435772791..a4472c443c44 100644
--- a/src/app/libp2p_helper/src/libp2p_helper/incoming_msg.go
+++ b/src/app/libp2p_helper/src/libp2p_helper/incoming_msg.go
@@ -40,34 +40,39 @@ var pushMesssageExtractors = map[ipc.Libp2pHelperInterface_PushMessage_Which]ext
 	ipc.Libp2pHelperInterface_PushMessage_Which_heartbeatPeer: fromHeartbeatPeerPush,
 }
 
+// Handles messages coming from the OCaml process
 func (app *app) handleIncomingMsg(msg *ipc.Libp2pHelperInterface_Message) {
 	if msg.HasRpcRequest() {
-		resp, err := func() (*capnp.Message, error) {
+		resp, afterWriteHandler, err := func() (*capnp.Message, func(), error) {
 			req, err := msg.RpcRequest()
 			if err != nil {
-				return nil, err
+				return nil, nil, err
 			}
 			h, err := req.Header()
 			if err != nil {
-				return nil, err
+				return nil, nil, err
 			}
 			seqnoO, err := h.SequenceNumber()
 			if err != nil {
-				return nil, err
+				return nil, nil, err
 			}
 			seqno := seqnoO.Seqno()
 			extractor, foundHandler := rpcRequestExtractors[req.Which()]
 			if !foundHandler {
-				return nil, errors.New("Received rpc message of an unknown type")
+				return nil, nil, errors.New("Received rpc message of an unknown type")
 			}
 			req2, err := extractor(req)
 			if err != nil {
-				return nil, err
+				return nil, nil, err
 			}
-			return req2.handle(app, seqno), nil
+			resp, afterWriteHandler := req2.handle(app, seqno)
+			return resp, afterWriteHandler, nil
 		}()
 		if err == nil {
 			app.writeMsg(resp)
+			if afterWriteHandler != nil {
+				afterWriteHandler()
+			}
 		} else {
 			app.P2p.Logger.Errorf("Failed to process rpc message: %s", err)
 		}
 	}
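
The dispatch change above is the heart of the `(*capnp.Message, func())` refactor: a handler can hand back an `afterWrite` hook, and the dispatcher guarantees the RPC response reaches the output stream before the hook can emit any follow-up messages. A minimal sketch of that ordering contract (names are illustrative, not the PR's API):

```go
package main

import "fmt"

// handler returns a reply plus an optional afterWrite hook, mirroring
// rpcRequest.handle in this PR.
type handler func() (reply string, afterWrite func())

func dispatch(out chan<- string, h handler) {
	reply, afterWrite := h()
	out <- reply // the response is written first ...
	if afterWrite != nil {
		afterWrite() // ... and only then may follow-up messages be produced
	}
}

func main() {
	out := make(chan string, 4)
	dispatch(out, func() (string, func()) {
		return "openStream response", func() { out <- "stream event" }
	})
	close(out)
	for m := range out {
		fmt.Println(m) // the response always precedes the stream event
	}
}
```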
app.P2p.Logger.Errorf("Failed to process rpc message: %s", err) } diff --git a/src/app/libp2p_helper/src/libp2p_helper/main_test.go b/src/app/libp2p_helper/src/libp2p_helper/main_test.go index 171365a9a3ef..bc135bafeae5 100644 --- a/src/app/libp2p_helper/src/libp2p_helper/main_test.go +++ b/src/app/libp2p_helper/src/libp2p_helper/main_test.go @@ -5,9 +5,9 @@ import ( "context" "fmt" "io" - "io/ioutil" "os" "strings" + "sync" "testing" "time" @@ -22,6 +22,8 @@ import ( net "github.com/libp2p/go-libp2p/core/network" + gonet "net" + ipc "libp2p_ipc" "github.com/stretchr/testify/require" @@ -53,7 +55,7 @@ const ( ) func TestMplex_SendLargeMessage(t *testing.T) { - // assert we are able to send and receive a message with size up to 1 << 30 bytes + // assert we are able to send and receive a message with size up to 1 MiB appA, _ := newTestApp(t, nil, true) appA.NoDHT = true @@ -67,7 +69,7 @@ func TestMplex_SendLargeMessage(t *testing.T) { err = appB.P2p.Host.Connect(appB.Ctx, appAInfos[0]) require.NoError(t, err) - msgSize := uint64(1 << 30) + msgSize := uint64(1 << 20) withTimeoutAsync(t, func(done chan interface{}) { // create handler that reads `msgSize` bytes @@ -263,16 +265,24 @@ func TestLibp2pMetrics(t *testing.T) { require.NoError(t, err) var streamIdx uint64 = 0 + var streamMutex sync.Mutex handler := func(stream net.Stream) { handleStreamReads(appB, stream, streamIdx) + streamMutex.Lock() + defer streamMutex.Unlock() streamIdx++ } appB.P2p.Host.SetStreamHandler(testProtocol, handler) + listener, err := gonet.Listen("tcp", ":0") + if err != nil { + panic(err) + } + port := listener.Addr().(*gonet.TCPAddr).Port server := http.NewServeMux() server.Handle("/metrics", promhttp.Handler()) - go http.ListenAndServe(":9001", server) + go http.Serve(listener, server) go appB.checkPeerCount() go appB.checkMessageStats() @@ -288,11 +298,11 @@ func TestLibp2pMetrics(t *testing.T) { expectedPeerCount := len(appB.P2p.Host.Network().Peers()) expectedCurrentConnCount := appB.P2p.ConnectionManager.GetInfo().ConnCount - resp, err := http.Get("http://localhost:9001/metrics") + resp, err := http.Get(fmt.Sprintf("http://localhost:%d/metrics", port)) require.NoError(t, err) defer resp.Body.Close() - body, err := ioutil.ReadAll(resp.Body) + body, err := io.ReadAll(resp.Body) require.NoError(t, err) respBody := string(body) diff --git a/src/app/libp2p_helper/src/libp2p_helper/message_id_test.go b/src/app/libp2p_helper/src/libp2p_helper/message_id_test.go index 8677e3527180..215bf0a1bd68 100644 --- a/src/app/libp2p_helper/src/libp2p_helper/message_id_test.go +++ b/src/app/libp2p_helper/src/libp2p_helper/message_id_test.go @@ -48,11 +48,18 @@ func testPubsubMsgIdFun(t *testing.T, topic string) { // Subscribe to the topic testSubscribeDo(t, alice, topic, 21, 58) + // Timeouts between subscriptions are needed because otherwise each process would try to discover peers + // and will only find that no other peers are connected to the same topic. 
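
The metrics-test change above swaps a hard-coded `:9001` for an ephemeral port: binding to port 0 lets the kernel pick a free port, which the test reads back from the listener, avoiding collisions between concurrent test runs. A standalone sketch of the technique:

```go
package main

import (
	"fmt"
	"net"
	"net/http"
)

func main() {
	// Bind to port 0 and let the kernel choose a free port.
	listener, err := net.Listen("tcp", ":0")
	if err != nil {
		panic(err)
	}
	port := listener.Addr().(*net.TCPAddr).Port

	mux := http.NewServeMux()
	mux.HandleFunc("/metrics", func(w http.ResponseWriter, r *http.Request) {
		fmt.Fprintln(w, "ok")
	})
	go http.Serve(listener, mux)

	// The client derives the URL from the port the kernel actually assigned.
	resp, err := http.Get(fmt.Sprintf("http://localhost:%d/metrics", port))
	if err != nil {
		panic(err)
	}
	defer resp.Body.Close()
	fmt.Println("status:", resp.StatusCode)
}
```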
diff --git a/src/app/libp2p_helper/src/libp2p_helper/message_id_test.go b/src/app/libp2p_helper/src/libp2p_helper/message_id_test.go
index 8677e3527180..215bf0a1bd68 100644
--- a/src/app/libp2p_helper/src/libp2p_helper/message_id_test.go
+++ b/src/app/libp2p_helper/src/libp2p_helper/message_id_test.go
@@ -48,11 +48,18 @@ func testPubsubMsgIdFun(t *testing.T, topic string) {
 
 	// Subscribe to the topic
 	testSubscribeDo(t, alice, topic, 21, 58)
+	// Pauses between subscriptions are needed: without them each process tries to discover peers
+	// before any other peer has subscribed to the topic, and pubsub's discovery is slow to recover.
+	time.Sleep(time.Second)
 	testSubscribeDo(t, bob, topic, 21, 58)
+	time.Sleep(time.Second)
 	testSubscribeDo(t, carol, topic, 21, 58)
+	time.Sleep(time.Second)
 
 	_ = testOpenStreamDo(t, bob, alice.P2p.Host, appAPort, 9900, string(newProtocol))
 	_ = testOpenStreamDo(t, carol, alice.P2p.Host, appAPort, 9900, string(newProtocol))
+
 	<-trapA.IncomingStream
 	<-trapA.IncomingStream
 
@@ -60,8 +67,7 @@ func testPubsubMsgIdFun(t *testing.T, topic string) {
 	testPublishDo(t, alice, topic, msg, 21)
 	testPublishDo(t, bob, topic, msg, 21)
 
-	time.Sleep(time.Millisecond * 100)
-
+	time.Sleep(time.Second)
 	n := 0
 loop:
 	for {
diff --git a/src/app/libp2p_helper/src/libp2p_helper/msg.go b/src/app/libp2p_helper/src/libp2p_helper/msg.go
index 6acf6b3a2803..053bbd640626 100644
--- a/src/app/libp2p_helper/src/libp2p_helper/msg.go
+++ b/src/app/libp2p_helper/src/libp2p_helper/msg.go
@@ -29,7 +29,12 @@ type extractPushMessage = func(ipcPushMessage) (pushMessage, error)
 type ipcRpcRequest = ipc.Libp2pHelperInterface_RpcRequest
 
 type rpcRequest interface {
-	handle(app *app, seqno uint64) *capnp.Message
+	// Handles an rpc request, returning the response and a function to be called
+	// immediately after the response is written to the output stream
+	//
+	// The callback is needed in some cases to make sure the response is written
+	// before any other messages get written to the output stream
+	handle(app *app, seqno uint64) (*capnp.Message, func())
 }
 
 type extractRequest = func(ipcRpcRequest) (rpcRequest, error)
@@ -207,7 +212,7 @@ func setNanoTime(ns *ipc.UnixNano, t time.Time) {
 	ns.SetNanoSec(t.UnixNano())
 }
 
-func mkRpcRespError(seqno uint64, rpcRespErr error) *capnp.Message {
+func mkRpcRespErrorNoFunc(seqno uint64, rpcRespErr error) *capnp.Message {
 	if rpcRespErr == nil {
 		panic("mkRpcRespError: nil error")
 	}
@@ -228,7 +233,11 @@ func mkRpcRespError(seqno uint64, rpcRespErr error) *capnp.Message {
 	})
 }
 
-func mkRpcRespSuccess(seqno uint64, f func(*ipc.Libp2pHelperInterface_RpcResponseSuccess)) *capnp.Message {
+func mkRpcRespError(seqno uint64, rpcRespErr error) (*capnp.Message, func()) {
+	return mkRpcRespErrorNoFunc(seqno, rpcRespErr), nil
+}
+
+func mkRpcRespSuccessNoFunc(seqno uint64, f func(*ipc.Libp2pHelperInterface_RpcResponseSuccess)) *capnp.Message {
 	return mkMsg(func(seg *capnp.Segment) {
 		m, err := ipc.NewRootDaemonInterface_Message(seg)
 		panicOnErr(err)
@@ -248,6 +257,10 @@ func mkRpcRespSuccess(seqno uint64, f func(*ipc.Libp2pHelperInterface_RpcRespons
 	})
 }
 
+func mkRpcRespSuccess(seqno uint64, f func(*ipc.Libp2pHelperInterface_RpcResponseSuccess)) (*capnp.Message, func()) {
+	return mkRpcRespSuccessNoFunc(seqno, f), nil
+}
+
 func mkPushMsg(f func(ipc.DaemonInterface_PushMessage)) *capnp.Message {
 	return mkMsg(func(seg *capnp.Segment) {
 		m, err := ipc.NewRootDaemonInterface_Message(seg)
diff --git a/src/app/libp2p_helper/src/libp2p_helper/multinode_test.go b/src/app/libp2p_helper/src/libp2p_helper/multinode_test.go
index 62bd54584cfd..5b2a66de368f 100644
--- a/src/app/libp2p_helper/src/libp2p_helper/multinode_test.go
+++ b/src/app/libp2p_helper/src/libp2p_helper/multinode_test.go
@@ -15,10 +15,10 @@ import (
 
 	capnp "capnproto.org/go/capnp/v3"
 	logging "github.com/ipfs/go-log/v2"
-	"github.com/libp2p/go-libp2p/core/crypto"
-	"github.com/libp2p/go-libp2p/core/peer"
 	kb "github.com/libp2p/go-libp2p-kbucket"
 	pubsub "github.com/libp2p/go-libp2p-pubsub"
+	"github.com/libp2p/go-libp2p/core/crypto"
+	"github.com/libp2p/go-libp2p/core/peer"
 	"github.com/stretchr/testify/require"
 )
diff --git a/src/app/libp2p_helper/src/libp2p_helper/peer_msg.go b/src/app/libp2p_helper/src/libp2p_helper/peer_msg.go
index 0b37f4010b93..63b08f587486 100644
--- a/src/app/libp2p_helper/src/libp2p_helper/peer_msg.go
+++ b/src/app/libp2p_helper/src/libp2p_helper/peer_msg.go
@@ -22,7 +22,7 @@ func fromAddPeerReq(req ipcRpcRequest) (rpcRequest, error) {
 	i, err := req.AddPeer()
 	return AddPeerReq(i), err
 }
-func (m AddPeerReq) handle(app *app, seqno uint64) *capnp.Message {
+func (m AddPeerReq) handle(app *app, seqno uint64) (*capnp.Message, func()) {
 	if app.P2p == nil {
 		return mkRpcRespError(seqno, needsConfigure())
 	}
@@ -40,19 +40,14 @@ func (m AddPeerReq) handle(app *app, seqno uint64) (*capnp.Message, func()) {
 		return mkRpcRespError(seqno, badRPC(err))
 	}
 
-	app.AddedPeers = append(app.AddedPeers, *info)
-	app.P2p.GatingState().TrustedPeers[info.ID] = struct{}{}
+	app.AddPeers(*info)
+	app.P2p.GatingState().TrustPeer(info.ID)
 
 	if app.Bootstrapper != nil {
 		app.Bootstrapper.Close()
 	}
 
 	app.P2p.Logger.Info("addPeer Trying to connect to: ", info)
-
-	if AddPeerReqT(m).IsSeed() {
-		app.P2p.Seeds = append(app.P2p.Seeds, *info)
-	}
-
 	err = app.P2p.Host.Connect(app.Ctx, *info)
 	if err != nil {
 		return mkRpcRespError(seqno, badp2p(err))
@@ -71,7 +66,7 @@ func fromGetPeerNodeStatusReq(req ipcRpcRequest) (rpcRequest, error) {
 	i, err := req.GetPeerNodeStatus()
 	return GetPeerNodeStatusReq(i), err
 }
-func (m GetPeerNodeStatusReq) handle(app *app, seqno uint64) *capnp.Message {
+func (m GetPeerNodeStatusReq) handle(app *app, seqno uint64) (*capnp.Message, func()) {
 	ctx, cancel := context.WithTimeout(app.Ctx, codanet.NodeStatusTimeout)
 	defer cancel()
 	pma, err := GetPeerNodeStatusReqT(m).Peer()
@@ -147,7 +142,7 @@ func fromListPeersReq(req ipcRpcRequest) (rpcRequest, error) {
 	i, err := req.ListPeers()
 	return ListPeersReq(i), err
 }
-func (msg ListPeersReq) handle(app *app, seqno uint64) *capnp.Message {
+func (msg ListPeersReq) handle(app *app, seqno uint64) (*capnp.Message, func()) {
 	if app.P2p == nil {
 		return mkRpcRespError(seqno, needsConfigure())
 	}
diff --git a/src/app/libp2p_helper/src/libp2p_helper/peer_msg_test.go b/src/app/libp2p_helper/src/libp2p_helper/peer_msg_test.go
index b1bde309b0fb..5cddac5b95f6 100644
--- a/src/app/libp2p_helper/src/libp2p_helper/peer_msg_test.go
+++ b/src/app/libp2p_helper/src/libp2p_helper/peer_msg_test.go
@@ -28,7 +28,7 @@ func testAddPeerImplDo(t *testing.T, node *app, peerAddr peer.AddrInfo, isSeed b
 	m.SetIsSeed(isSeed)
 
 	var mRpcSeqno uint64 = 2000
-	resMsg := AddPeerReq(m).handle(node, mRpcSeqno)
+	resMsg, _ := AddPeerReq(m).handle(node, mRpcSeqno)
 	seqno, respSuccess := checkRpcResponseSuccess(t, resMsg, "addPeer")
 	require.Equal(t, seqno, mRpcSeqno)
 	require.True(t, respSuccess.HasAddPeer())
@@ -88,7 +88,7 @@ func TestGetPeerNodeStatus(t *testing.T) {
 	require.NoError(t, ma.SetRepresentation(addr))
 
 	var mRpcSeqno uint64 = 18900
-	resMsg := GetPeerNodeStatusReq(m).handle(appB, mRpcSeqno)
+	resMsg, _ := GetPeerNodeStatusReq(m).handle(appB, mRpcSeqno)
 	seqno, respSuccess := checkRpcResponseSuccess(t, resMsg, "getPeerNodeStatus")
 	require.Equal(t, seqno, mRpcSeqno)
 	require.True(t, respSuccess.HasGetPeerNodeStatus())
@@ -108,7 +108,7 @@ func TestListPeers(t *testing.T) {
 	require.NoError(t, err)
 
 	var mRpcSeqno uint64 = 2002
-	resMsg := ListPeersReq(m).handle(appB, mRpcSeqno)
+	resMsg, _ := ListPeersReq(m).handle(appB, mRpcSeqno)
 	seqno, respSuccess := checkRpcResponseSuccess(t, resMsg, "listPeers")
 	require.Equal(t, seqno, mRpcSeqno)
 	require.True(t, respSuccess.HasListPeers())
diff --git a/src/app/libp2p_helper/src/libp2p_helper/pubsub_msg.go b/src/app/libp2p_helper/src/libp2p_helper/pubsub_msg.go
index 11e23382468c..12a167be5e13 100644
--- a/src/app/libp2p_helper/src/libp2p_helper/pubsub_msg.go
+++ b/src/app/libp2p_helper/src/libp2p_helper/pubsub_msg.go
@@ -8,8 +8,8 @@ import (
 
 	capnp "capnproto.org/go/capnp/v3"
 	"github.com/go-errors/errors"
-	peer "github.com/libp2p/go-libp2p/core/peer"
 	pubsub "github.com/libp2p/go-libp2p-pubsub"
+	peer "github.com/libp2p/go-libp2p/core/peer"
 )
 
 type ValidationPushT = ipc.Libp2pHelperInterface_Validation
@@ -35,26 +35,23 @@ func (m ValidationPush) handle(app *app) {
 		app.P2p.Logger.Errorf("handleValidation: error %s", err)
 		return
 	}
+	res := ValidationUnknown
+	switch ValidationPushT(m).Result() {
+	case ipc.ValidationResult_accept:
+		res = pubsub.ValidationAccept
+	case ipc.ValidationResult_reject:
+		res = pubsub.ValidationReject
+	case ipc.ValidationResult_ignore:
+		res = pubsub.ValidationIgnore
+	default:
+		app.P2p.Logger.Warnf("handleValidation: unknown validation result %d", ValidationPushT(m).Result())
+	}
 	seqno := vid.Id()
-	app.ValidatorMutex.Lock()
-	defer app.ValidatorMutex.Unlock()
-	if st, ok := app.Validators[seqno]; ok {
-		res := ValidationUnknown
-		switch ValidationPushT(m).Result() {
-		case ipc.ValidationResult_accept:
-			res = pubsub.ValidationAccept
-		case ipc.ValidationResult_reject:
-			res = pubsub.ValidationReject
-		case ipc.ValidationResult_ignore:
-			res = pubsub.ValidationIgnore
-		default:
-			app.P2p.Logger.Warnf("handleValidation: unknown validation result %d", ValidationPushT(m).Result())
-		}
+	if st, found := app.RemoveValidator(seqno); found {
 		st.Completion <- res
 		if st.TimedOutAt != nil {
 			app.P2p.Logger.Errorf("validation for item %d took %d seconds", seqno, time.Now().Add(validationTimeout).Sub(*st.TimedOutAt))
 		}
-		delete(app.Validators, seqno)
 	} else {
 		app.P2p.Logger.Warnf("handleValidation: validation seqno %d unknown", seqno)
 	}
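
The rework above also changes a locking detail worth noting: the old code sent on `st.Completion` while still holding `ValidatorMutex`, whereas the new code removes the entry under the lock and sends afterwards, so a slow receiver can no longer block every other validator operation. A toy version of that remove-then-send pattern (`completions` is a hypothetical type):

```go
package main

import (
	"fmt"
	"sync"
)

// completions mimics the validator registry: remove holds the mutex only for
// the map operations; the channel send happens after the lock is released.
type completions struct {
	mu sync.Mutex
	m  map[uint64]chan string
}

func (c *completions) remove(id uint64) (chan string, bool) {
	c.mu.Lock()
	defer c.mu.Unlock()
	ch, ok := c.m[id]
	delete(c.m, id)
	return ch, ok
}

func main() {
	c := &completions{m: map[uint64]chan string{7: make(chan string, 1)}}
	if ch, ok := c.remove(7); ok {
		ch <- "accept" // the registry lock is no longer held here
		fmt.Println(<-ch)
	}
}
```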
@@ -67,7 +64,7 @@ func fromPublishReq(req ipcRpcRequest) (rpcRequest, error) {
 	i, err := req.Publish()
 	return PublishReq(i), err
 }
-func (m PublishReq) handle(app *app, seqno uint64) *capnp.Message {
+func (m PublishReq) handle(app *app, seqno uint64) (*capnp.Message, func()) {
 	if app.P2p == nil {
 		return mkRpcRespError(seqno, needsConfigure())
 	}
@@ -87,12 +84,12 @@ func (m PublishReq) handle(app *app, seqno uint64) *capnp.Message {
 		return mkRpcRespError(seqno, badRPC(err))
 	}
 
-	if topic, has = app.Topics[topicName]; !has {
+	if topic, has = app.GetTopic(topicName); !has {
 		topic, err = app.P2p.Pubsub.Join(topicName)
 		if err != nil {
 			return mkRpcRespError(seqno, badp2p(err))
 		}
-		app.Topics[topicName] = topic
+		app.AddTopic(topicName, topic)
 	}
 
 	if err := topic.Publish(app.Ctx, data); err != nil {
@@ -112,7 +109,7 @@ func fromSubscribeReq(req ipcRpcRequest) (rpcRequest, error) {
 	i, err := req.Subscribe()
 	return SubscribeReq(i), err
 }
-func (m SubscribeReq) handle(app *app, seqno uint64) *capnp.Message {
+func (m SubscribeReq) handle(app *app, seqno uint64) (*capnp.Message, func()) {
 	if app.P2p == nil {
 		return mkRpcRespError(seqno, needsConfigure())
 	}
@@ -136,7 +133,7 @@ func (m SubscribeReq) handle(app *app, seqno uint64) *capnp.Message {
 		return mkRpcRespError(seqno, badp2p(err))
 	}
 
-	app.Topics[topicName] = topic
+	app.AddTopic(topicName, topic)
 
 	err = app.P2p.Pubsub.RegisterTopicValidator(topicName, func(ctx context.Context, id peer.ID, msg *pubsub.Message) pubsub.ValidationResult {
 		app.P2p.Logger.Debugf("Received gossip message on topic %s from %s", topicName, id.Pretty())
@@ -148,12 +145,7 @@ func (m SubscribeReq) handle(app *app, seqno uint64) *capnp.Message {
 
 		seenAt := time.Now()
-		seqno := app.NextId()
-		ch := make(chan pubsub.ValidationResult)
-		app.ValidatorMutex.Lock()
-		app.Validators[seqno] = new(validationStatus)
-		app.Validators[seqno].Completion = ch
-		app.ValidatorMutex.Unlock()
+		seqno, ch := app.AddValidator()
 
 		app.P2p.Logger.Info("validating a new pubsub message ...")
 
@@ -161,17 +153,14 @@ func (m SubscribeReq) handle(app *app, seqno uint64) *capnp.Message {
 
 		if err != nil && !app.UnsafeNoTrustIP {
 			app.P2p.Logger.Errorf("failed to connect to peer %s that just sent us a pubsub message, dropping it", peer.Encode(id))
-			app.ValidatorMutex.Lock()
-			defer app.ValidatorMutex.Unlock()
-			delete(app.Validators, seqno)
+			app.RemoveValidator(seqno)
 			return pubsub.ValidationIgnore
 		}
 
 		deadline, ok := ctx.Deadline()
 		if !ok {
 			app.P2p.Logger.Errorf("no deadline set on validation context")
-			defer app.ValidatorMutex.Unlock()
-			delete(app.Validators, seqno)
+			app.RemoveValidator(seqno)
 			return pubsub.ValidationIgnore
 		}
 		app.writeMsg(mkGossipReceivedUpcall(sender, deadline, seenAt, msg.Data, seqno, subId))
@@ -187,12 +176,7 @@ func (m SubscribeReq) handle(app *app, seqno uint64) *capnp.Message {
 
 			validationTimeoutMetric.Inc()
 
-			app.ValidatorMutex.Lock()
-
-			now := time.Now()
-			app.Validators[seqno].TimedOutAt = &now
-
-			app.ValidatorMutex.Unlock()
+			app.TimeoutValidator(seqno)
 
 			if app.UnsafeNoTrustIP {
 				app.P2p.Logger.Info("validated anyway!")
@@ -228,12 +212,11 @@ func (m SubscribeReq) handle(app *app, seqno uint64) *capnp.Message {
 	}
 
 	ctx, cancel := context.WithCancel(app.Ctx)
-	app.Subs[subId] = subscription{
+	app.AddSubscription(subId, subscription{
 		Sub:    sub,
-		Idx:    subId,
-		Ctx:    ctx,
 		Cancel: cancel,
-	}
+	})
+
 	go func() {
 		for {
 			_, err = sub.Next(ctx)
@@ -259,7 +242,7 @@ func fromUnsubscribeReq(req ipcRpcRequest) (rpcRequest, error) {
 	i, err := req.Unsubscribe()
 	return UnsubscribeReq(i), err
 }
-func (m UnsubscribeReq) handle(app *app, seqno uint64) *capnp.Message {
+func (m UnsubscribeReq) handle(app *app, seqno uint64) (*capnp.Message, func()) {
 	if app.P2p == nil {
 		return mkRpcRespError(seqno, needsConfigure())
 	}
@@ -268,14 +251,14 @@ func (m UnsubscribeReq) handle(app *app, seqno uint64) *capnp.Message {
 		return mkRpcRespError(seqno, badRPC(err))
 	}
 	subId := subId_.Id()
-	if sub, ok := app.Subs[subId]; ok {
+	if sub, found := app.RemoveSubscription(subId); found {
 		sub.Sub.Cancel()
 		sub.Cancel()
-		delete(app.Subs, subId)
 		return mkRpcRespSuccess(seqno, func(m *ipc.Libp2pHelperInterface_RpcResponseSuccess) {
 			_, err := m.NewUnsubscribe()
 			panicOnErr(err)
 		})
+	} else {
+		return mkRpcRespError(seqno, badRPC(errors.New("subscription not found")))
 	}
-	return mkRpcRespError(seqno, badRPC(errors.New("subscription not found")))
 }
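
On the subscription lifecycle above: each subscription stores only its `Cancel` function, and `Unsubscribe` removes the registry entry first, then cancels the context, which stops the background goroutine draining `sub.Next(ctx)`. A self-contained sketch of that cancel-driven teardown (the drain loop is a stand-in, not the PR's exact goroutine):

```go
package main

import (
	"context"
	"fmt"
	"time"
)

func main() {
	ctx, cancel := context.WithCancel(context.Background())
	done := make(chan struct{})

	// Background drain loop, shaped like the goroutine calling sub.Next(ctx).
	go func() {
		defer close(done)
		for {
			select {
			case <-ctx.Done():
				fmt.Println("drain loop stopped")
				return
			case <-time.After(10 * time.Millisecond):
				// stand-in for receiving the next pubsub message
			}
		}
	}()

	cancel() // Unsubscribe: cancelling the context stops the loop
	<-done
}
```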
diff --git a/src/app/libp2p_helper/src/libp2p_helper/pubsub_msg_test.go b/src/app/libp2p_helper/src/libp2p_helper/pubsub_msg_test.go
index 7ae62a55f510..db4f5b5a67e3 100644
--- a/src/app/libp2p_helper/src/libp2p_helper/pubsub_msg_test.go
+++ b/src/app/libp2p_helper/src/libp2p_helper/pubsub_msg_test.go
@@ -1,6 +1,7 @@
 package main
 
 import (
+	"math/rand"
 	"testing"
 
 	"github.com/stretchr/testify/require"
@@ -19,7 +20,7 @@ func testPublishDo(t *testing.T, app *app, topic string, data []byte, rpcSeqno u
 	require.NoError(t, m.SetTopic(topic))
 	require.NoError(t, m.SetData(data))
 
-	resMsg := PublishReq(m).handle(app, rpcSeqno)
+	resMsg, _ := PublishReq(m).handle(app, rpcSeqno)
 	require.NoError(t, err)
 	seqno, respSuccess := checkRpcResponseSuccess(t, resMsg, "publish")
 	require.Equal(t, seqno, rpcSeqno)
@@ -27,7 +28,7 @@ func testPublishDo(t *testing.T, app *app, topic string, data []byte, rpcSeqno u
 	_, err = respSuccess.Publish()
 	require.NoError(t, err)
 
-	_, has := app.Topics[topic]
+	_, has := app._topics[topic]
 	require.True(t, has)
 }
 
@@ -47,7 +48,7 @@ func testSubscribeDo(t *testing.T, app *app, topic string, subId uint64, rpcSeqn
 	require.NoError(t, err)
 	sid.SetId(subId)
 
-	resMsg := SubscribeReq(m).handle(app, rpcSeqno)
+	resMsg, _ := SubscribeReq(m).handle(app, rpcSeqno)
 	require.NoError(t, err)
 	seqno, respSuccess := checkRpcResponseSuccess(t, resMsg, "subscribe")
 	require.Equal(t, seqno, rpcSeqno)
@@ -55,9 +56,9 @@ func testSubscribeDo(t *testing.T, app *app, topic string, subId uint64, rpcSeqn
 	_, err = respSuccess.Subscribe()
 	require.NoError(t, err)
 
-	_, has := app.Topics[topic]
+	_, has := app._topics[topic]
 	require.True(t, has)
-	_, has = app.Subs[subId]
+	_, has = app._subs[subId]
 	require.True(t, has)
 }
 
@@ -89,7 +90,7 @@ func TestUnsubscribe(t *testing.T) {
 	require.NoError(t, err)
 	sid.SetId(idx)
 
-	resMsg := UnsubscribeReq(m).handle(testApp, 7739)
+	resMsg, _ := UnsubscribeReq(m).handle(testApp, 7739)
 	require.NoError(t, err)
 	seqno, respSuccess := checkRpcResponseSuccess(t, resMsg, "unsubscribe")
 	require.Equal(t, seqno, uint64(7739))
@@ -97,46 +98,37 @@ func TestUnsubscribe(t *testing.T) {
 	_, err = respSuccess.Unsubscribe()
 	require.NoError(t, err)
 
-	_, has := testApp.Subs[idx]
+	_, has := testApp._subs[idx]
 	require.False(t, has)
 }
 
 func TestValidationPush(t *testing.T) {
 	testApp, _ := newTestApp(t, nil, true)
 
-	ipcValResults := []ipc.ValidationResult{
-		ipc.ValidationResult_accept,
-		ipc.ValidationResult_reject,
-		ipc.ValidationResult_ignore,
+	ipc2Pubsub := map[ipc.ValidationResult]pubsub.ValidationResult{
+		ipc.ValidationResult_accept: pubsub.ValidationAccept,
+		ipc.ValidationResult_reject: pubsub.ValidationReject,
+		ipc.ValidationResult_ignore: pubsub.ValidationIgnore,
 	}
 
-	pubsubValResults := []pubsub.ValidationResult{
-		pubsub.ValidationAccept,
-		pubsub.ValidationReject,
-		pubsub.ValidationIgnore,
-	}
-
-	for i := 0; i < len(ipcValResults); i++ {
-		result := ValidationUnknown
-		seqno := uint64(i)
+	for resIpc, resPS := range ipc2Pubsub {
+		seqno := rand.Uint64()
 		status := &validationStatus{
-			Completion: make(chan pubsub.ValidationResult),
+			Completion: make(chan pubsub.ValidationResult, 1),
 		}
-		testApp.Validators[seqno] = status
-		go func() {
-			result = <-status.Completion
-		}()
+		testApp._validators[seqno] = status
 		_, seg, err := capnp.NewMessage(capnp.SingleSegment(nil))
 		require.NoError(t, err)
 		m, err := ipc.NewRootLibp2pHelperInterface_Validation(seg)
 		require.NoError(t, err)
 		validationId, err := m.NewValidationId()
 		validationId.SetId(seqno)
-		m.SetResult(ipcValResults[i])
+		m.SetResult(resIpc)
 		ValidationPush(m).handle(testApp)
 		require.NoError(t, err)
-		require.Equal(t, pubsubValResults[i], result)
-		_, has := testApp.Validators[seqno]
+		result := <-status.Completion
+		require.Equal(t, resPS, result)
+		_, has := testApp._validators[seqno]
 		require.False(t, has)
 	}
 }
diff --git a/src/app/libp2p_helper/src/libp2p_helper/stream_msg.go b/src/app/libp2p_helper/src/libp2p_helper/stream_msg.go
index bc4cc9ad827e..13b2136ea533 100644
--- a/src/app/libp2p_helper/src/libp2p_helper/stream_msg.go
+++ b/src/app/libp2p_helper/src/libp2p_helper/stream_msg.go
@@ -2,7 +2,6 @@ package main
 
 import (
 	"context"
-	"fmt"
 	"time"
 
 	ipc "libp2p_ipc"
@@ -21,7 +20,7 @@ func fromAddStreamHandlerReq(req ipcRpcRequest) (rpcRequest, error) {
 	i, err := req.AddStreamHandler()
 	return AddStreamHandlerReq(i), err
 }
-func (m AddStreamHandlerReq) handle(app *app, seqno uint64) *capnp.Message {
+func (m AddStreamHandlerReq) handle(app *app, seqno uint64) (*capnp.Message, func()) {
 	if app.P2p == nil {
 		return mkRpcRespError(seqno, needsConfigure())
 	}
@@ -35,10 +34,7 @@ func (m AddStreamHandlerReq) handle(app *app, seqno uint64) *capnp.Message {
 			app.P2p.Logger.Errorf("failed to parse remote connection information, silently dropping stream: %s", err.Error())
 			return
 		}
-		streamIdx := app.NextId()
-		app.StreamsMutex.Lock()
-		defer app.StreamsMutex.Unlock()
-		app.Streams[streamIdx] = stream
+		streamIdx := app.AddStream(stream)
 		app.writeMsg(mkIncomingStreamUpcall(peerinfo, streamIdx, protocolId))
 		handleStreamReads(app, stream, streamIdx)
 	})
@@ -56,7 +52,7 @@ func fromCloseStreamReq(req ipcRpcRequest) (rpcRequest, error) {
 	i, err := req.CloseStream()
 	return CloseStreamReq(i), err
 }
-func (m CloseStreamReq) handle(app *app, seqno uint64) *capnp.Message {
+func (m CloseStreamReq) handle(app *app, seqno uint64) (*capnp.Message, func()) {
 	if app.P2p == nil {
 		return mkRpcRespError(seqno, needsConfigure())
 	}
@@ -65,20 +61,20 @@ func (m CloseStreamReq) handle(app *app, seqno uint64) (*capnp.Message, func()) {
 		return mkRpcRespError(seqno, badRPC(err))
 	}
 	streamId := sid.Id()
-	app.StreamsMutex.Lock()
-	defer app.StreamsMutex.Unlock()
-	if stream, ok := app.Streams[streamId]; ok {
-		delete(app.Streams, streamId)
-		err := stream.Close()
-		if err != nil {
-			return mkRpcRespError(seqno, badp2p(err))
+	if stream, found := app.RemoveStream(streamId); found {
+		if err2 := stream.Close(); err2 != nil {
+			err = badp2p(err2)
 		}
-		return mkRpcRespSuccess(seqno, func(m *ipc.Libp2pHelperInterface_RpcResponseSuccess) {
-			_, err := m.NewCloseStream()
-			panicOnErr(err)
-		})
+	} else {
+		err = badRPC(errors.New("unknown stream_idx"))
+	}
+	if err != nil {
+		return mkRpcRespError(seqno, err)
 	}
-	return mkRpcRespError(seqno, badRPC(errors.New("unknown stream_idx")))
+	return mkRpcRespSuccess(seqno, func(m *ipc.Libp2pHelperInterface_RpcResponseSuccess) {
+		_, err := m.NewCloseStream()
+		panicOnErr(err)
+	})
 }
 
 type OpenStreamReqT = ipc.Libp2pHelperInterface_OpenStream_Request
@@ -88,12 +84,11 @@ func fromOpenStreamReq(req ipcRpcRequest) (rpcRequest, error) {
 	i, err := req.OpenStream()
 	return OpenStreamReq(i), err
 }
-func (m OpenStreamReq) handle(app *app, seqno uint64) *capnp.Message {
+func (m OpenStreamReq) handle(app *app, seqno uint64) (*capnp.Message, func()) {
 	if app.P2p == nil {
 		return mkRpcRespError(seqno, needsConfigure())
 	}
 
-	streamIdx := app.NextId()
 	var peerDecoded peer.ID
 	var protocolId string
 	err := func() error {
@@ -133,16 +128,8 @@ func (m OpenStreamReq) handle(app *app, seqno uint64) *capnp.Message {
 		return mkRpcRespError(seqno, badp2p(err))
 	}
 
-	app.StreamsMutex.Lock()
-	defer app.StreamsMutex.Unlock()
-	app.Streams[streamIdx] = stream
-	go func() {
-		// FIXME HACK: allow time for the openStreamResult to get printed before we start inserting stream events
-		time.Sleep(250 * time.Millisecond)
-		// Note: It is _very_ important that we call handleStreamReads here -- this is how the "caller" side of the stream starts listening to the responses from the RPCs. Do not remove.
-		handleStreamReads(app, stream, streamIdx)
-	}()
-	return mkRpcRespSuccess(seqno, func(m *ipc.Libp2pHelperInterface_RpcResponseSuccess) {
+	streamIdx := app.AddStream(stream)
+	mkResponse := func(m *ipc.Libp2pHelperInterface_RpcResponseSuccess) {
 		resp, err := m.NewOpenStream()
 		panicOnErr(err)
 		sid, err := resp.NewStreamId()
@@ -151,7 +138,10 @@ func (m OpenStreamReq) handle(app *app, seqno uint64) *capnp.Message {
 		pi, err := resp.NewPeer()
 		panicOnErr(err)
 		setPeerInfo(pi, peer)
-	})
+	}
+	return mkRpcRespSuccessNoFunc(seqno, mkResponse), func() {
+		handleStreamReads(app, stream, streamIdx)
+	}
 }
streamId, err.Error(), close_err.Error()) - } + err = app.WriteStream(streamId, data) - return mkRpcRespError(seqno, wrapError(badp2p(err), fmt.Sprintf("only wrote %d out of %d bytes", n, len(data)))) - } - return mkRpcRespSuccess(seqno, func(m *ipc.Libp2pHelperInterface_RpcResponseSuccess) { - _, err := m.NewSendStream() - panicOnErr(err) - }) + if err != nil { + return mkRpcRespError(seqno, err) } - return mkRpcRespError(seqno, badRPC(errors.New("unknown stream_idx"))) + + return mkRpcRespSuccess(seqno, func(m *ipc.Libp2pHelperInterface_RpcResponseSuccess) { + _, err := m.NewSendStream() + panicOnErr(err) + }) } diff --git a/src/app/libp2p_helper/src/libp2p_helper/stream_msg_test.go b/src/app/libp2p_helper/src/libp2p_helper/stream_msg_test.go index b21b6d3f2632..8e4cd233a5de 100644 --- a/src/app/libp2p_helper/src/libp2p_helper/stream_msg_test.go +++ b/src/app/libp2p_helper/src/libp2p_helper/stream_msg_test.go @@ -2,6 +2,7 @@ package main import ( "context" + "math/rand" "testing" "github.com/stretchr/testify/require" @@ -19,7 +20,7 @@ func testAddStreamHandlerDo(t *testing.T, protocol string, app *app, rpcSeqno ui require.NoError(t, err) require.NoError(t, m.SetProtocol(protocol)) - resMsg := AddStreamHandlerReq(m).handle(app, rpcSeqno) + resMsg, _ := AddStreamHandlerReq(m).handle(app, rpcSeqno) seqno, respSuccess := checkRpcResponseSuccess(t, resMsg, "addStreamHandler") require.Equal(t, seqno, rpcSeqno) require.True(t, respSuccess.HasAddStreamHandler()) @@ -58,7 +59,10 @@ func testOpenStreamDo(t *testing.T, appA *app, appBHost host.Host, appBPort uint require.NoError(t, pid.SetId(appBHost.ID().String())) require.NoError(t, err) - resMsg := OpenStreamReq(m).handle(appA, rpcSeqno) + resMsg, afterWriteHandler := OpenStreamReq(m).handle(appA, rpcSeqno) + if afterWriteHandler != nil { + afterWriteHandler() + } seqno, respSuccess := checkRpcResponseSuccess(t, resMsg, "openStream") require.Equal(t, seqno, rpcSeqno) require.True(t, respSuccess.HasOpenStream()) @@ -76,7 +80,7 @@ func testOpenStreamDo(t *testing.T, appA *app, appBHost host.Host, appBPort uint require.Equal(t, appA.counter, respStreamId) - _, has := appA.Streams[respStreamId] + _, has := appA._streams[respStreamId] require.True(t, has) return respStreamId @@ -103,14 +107,14 @@ func testCloseStreamDo(t *testing.T, app *app, streamId uint64, rpcSeqno uint64) require.NoError(t, err) sid.SetId(streamId) - resMsg := CloseStreamReq(m).handle(app, rpcSeqno) + resMsg, _ := CloseStreamReq(m).handle(app, rpcSeqno) seqno, respSuccess := checkRpcResponseSuccess(t, resMsg, "closeStream") require.Equal(t, seqno, rpcSeqno) require.True(t, respSuccess.HasCloseStream()) _, err = respSuccess.CloseStream() require.NoError(t, err) - _, has := app.Streams[streamId] + _, has := app._streams[streamId] require.False(t, has) } @@ -134,7 +138,7 @@ func TestRemoveStreamHandler(t *testing.T) { require.NoError(t, err) require.NoError(t, rsh.SetProtocol(newProtocol)) var rshRpcSeqno uint64 = 1023 - resMsg := RemoveStreamHandlerReq(rsh).handle(appB, rshRpcSeqno) + resMsg, _ := RemoveStreamHandlerReq(rsh).handle(appB, rshRpcSeqno) seqno, respSuccess := checkRpcResponseSuccess(t, resMsg, "removeStreamHandler") require.Equal(t, seqno, rshRpcSeqno) require.True(t, respSuccess.HasRemoveStreamHandler()) @@ -151,7 +155,10 @@ func TestRemoveStreamHandler(t *testing.T) { require.NoError(t, err) var osRpcSeqno uint64 = 1026 - osResMsg := OpenStreamReq(os).handle(appA, osRpcSeqno) + osResMsg, afterWriteHandler := OpenStreamReq(os).handle(appA, osRpcSeqno) + if 
afterWriteHandler != nil { + afterWriteHandler() + } osRpcSeqno_, errMsg := checkRpcResponseError(t, osResMsg) require.Equal(t, osRpcSeqno, osRpcSeqno_) require.Equal(t, "libp2p error: protocols not supported: [/mina/99]", errMsg) @@ -166,14 +173,14 @@ func testResetStreamDo(t *testing.T, app *app, streamId uint64, rpcSeqno uint64) require.NoError(t, err) sid.SetId(streamId) - resMsg := ResetStreamReq(m).handle(app, rpcSeqno) + resMsg, _ := ResetStreamReq(m).handle(app, rpcSeqno) seqno, respSuccess := checkRpcResponseSuccess(t, resMsg, "resetStream") require.Equal(t, seqno, rpcSeqno) require.True(t, respSuccess.HasResetStream()) _, err = respSuccess.ResetStream() require.NoError(t, err) - _, has := app.Streams[streamId] + _, has := app._streams[streamId] require.False(t, has) } @@ -182,6 +189,22 @@ func TestResetStream(t *testing.T) { testResetStreamDo(t, appA, streamId, 114558) } +func testSendStreamFailDo(t *testing.T, app *app, streamId uint64, msgBytes []byte, rpcSeqno uint64) { + _, seg, err := capnp.NewMessage(capnp.SingleSegment(nil)) + require.NoError(t, err) + m, err := ipc.NewRootLibp2pHelperInterface_SendStream_Request(seg) + require.NoError(t, err) + msg, err := m.NewMsg() + require.NoError(t, err) + sid, err := msg.NewStreamId() + require.NoError(t, err) + sid.SetId(streamId) + require.NoError(t, msg.SetData(msgBytes)) + + resMsg, _ := SendStreamReq(m).handle(app, rpcSeqno) + checkRpcResponseError(t, resMsg) +} + func testSendStreamDo(t *testing.T, app *app, streamId uint64, msgBytes []byte, rpcSeqno uint64) { _, seg, err := capnp.NewMessage(capnp.SingleSegment(nil)) require.NoError(t, err) @@ -194,14 +217,14 @@ func testSendStreamDo(t *testing.T, app *app, streamId uint64, msgBytes []byte, sid.SetId(streamId) require.NoError(t, msg.SetData(msgBytes)) - resMsg := SendStreamReq(m).handle(app, rpcSeqno) + resMsg, _ := SendStreamReq(m).handle(app, rpcSeqno) seqno, respSuccess := checkRpcResponseSuccess(t, resMsg, "sendStream") require.Equal(t, seqno, rpcSeqno) require.True(t, respSuccess.HasSendStream()) _, err = respSuccess.SendStream() require.NoError(t, err) - _, has := app.Streams[streamId] + _, has := app._streams[streamId] require.True(t, has) } @@ -221,7 +244,7 @@ func TestOpenStreamBeforeAndAfterSetGatingConfig(t *testing.T) { aUpcallErrChan := make(chan error) launchFeedUpcallTrap(appA.P2p.Logger, appA.OutChan, aTrap, aUpcallErrChan, ctx) - appB, appBPort := newTestApp(t, appAInfos, false) + appB, appBPort := newTestApp(t, nil, false) err = appB.P2p.Host.Connect(appB.Ctx, appAInfos[0]) require.NoError(t, err) bTrap := newUpcallTrap("appB", 64, upcallDropAllMask^(1<
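
Note on the new handler signature: every handle method now returns (*capnp.Message, func()) instead of just the message. The optional callback is intended to run only after the RPC response has been written (the tests call it afterWriteHandler), which is what lets OpenStreamReq.handle drop the removed FIXME HACK 250ms sleep before handleStreamReads. A sketch of the dispatch shape this implies; processRequest is hypothetical and the real loop lives in message-processing code outside this diff, but writeMsg appears in the hunks above:

// Hypothetical dispatch shape implied by the two-value handler signature.
// mkRpcRespError and mkRpcRespSuccess appear to return a nil callback
// alongside the message, since handlers return their results directly.
func processRequest(app *app, seqno uint64, req rpcRequest) {
	msg, afterWrite := req.handle(app, seqno)
	app.writeMsg(msg) // flush the RPC response to the parent process first
	if afterWrite != nil {
		// Only now start consuming the stream, so stream events cannot
		// race ahead of the openStream response.
		afterWrite()
	}
}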
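Note on the new app helpers: the handlers no longer touch the stream map directly; they call AddStream, RemoveStream, and WriteStream on *app, which this PR defines alongside the renamed app._streams field visible in the tests. A minimal sketch of what such helpers could look like, assuming _streams is a map[uint64]network.Stream guarded by a _streamsMutex sync.RWMutex (the mutex name is an assumption; NextId, badRPC, badp2p, and wrapError are the package's existing helpers):

// AddStream registers an opened stream under a fresh index and returns
// that index for use in upcalls and later RPCs.
func (app *app) AddStream(stream network.Stream) uint64 {
	streamIdx := app.NextId()
	app._streamsMutex.Lock()
	defer app._streamsMutex.Unlock()
	app._streams[streamIdx] = stream
	return streamIdx
}

// RemoveStream unregisters the stream and reports whether the id was
// known; closing or resetting it is left to the caller, as in the
// CloseStream and ResetStream handlers above.
func (app *app) RemoveStream(streamId uint64) (network.Stream, bool) {
	app._streamsMutex.Lock()
	defer app._streamsMutex.Unlock()
	stream, found := app._streams[streamId]
	if found {
		delete(app._streams, streamId)
	}
	return stream, found
}

// WriteStream writes data to the stream, returning an error already
// wrapped for the RPC response (SendStreamReq.handle passes it to
// mkRpcRespError unchanged). Holding the lock only for the lookup keeps
// the write itself off the global mutex, addressing the removed TODO.
func (app *app) WriteStream(streamId uint64, data []byte) error {
	app._streamsMutex.RLock()
	stream, found := app._streams[streamId]
	app._streamsMutex.RUnlock()
	if !found {
		return badRPC(errors.New("unknown stream_idx"))
	}
	if n, err := stream.Write(data); err != nil {
		return wrapError(badp2p(err), fmt.Sprintf("only wrote %d out of %d bytes", n, len(data)))
	}
	return nil
}

The real WriteStream may additionally close and drop the stream on a write failure, as the removed SendStream code did; testSendStreamFailDo only pins down that an unknown or failed write yields an RPC error response.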