diff --git a/dot/network/service.go b/dot/network/service.go index 738f747378..e662cca5ec 100644 --- a/dot/network/service.go +++ b/dot/network/service.go @@ -593,7 +593,7 @@ func (s *Service) SendMessage(to peer.ID, msg NotificationsMessage) error { } func (s *Service) GetRequestResponseProtocol(subprotocol string, requestTimeout time.Duration, - maxResponseSize uint64) *RequestResponseProtocol { + maxResponseSize uint64) RequestMaker { protocolID := s.host.protocolID + protocol.ID(subprotocol) return &RequestResponseProtocol{ diff --git a/dot/parachain/collator-protocol/mocks_test.go b/dot/parachain/collator-protocol/mocks_test.go index cfc5504e56..1ae610cc29 100644 --- a/dot/parachain/collator-protocol/mocks_test.go +++ b/dot/parachain/collator-protocol/mocks_test.go @@ -23,6 +23,7 @@ import ( type MockNetwork struct { ctrl *gomock.Controller recorder *MockNetworkMockRecorder + isgomock struct{} } // MockNetworkMockRecorder is the mock recorder for MockNetwork. @@ -43,41 +44,41 @@ func (m *MockNetwork) EXPECT() *MockNetworkMockRecorder { } // GetRequestResponseProtocol mocks base method. -func (m *MockNetwork) GetRequestResponseProtocol(arg0 string, arg1 time.Duration, arg2 uint64) *network.RequestResponseProtocol { +func (m *MockNetwork) GetRequestResponseProtocol(subprotocol string, requestTimeout time.Duration, maxResponseSize uint64) network.RequestMaker { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "GetRequestResponseProtocol", arg0, arg1, arg2) - ret0, _ := ret[0].(*network.RequestResponseProtocol) + ret := m.ctrl.Call(m, "GetRequestResponseProtocol", subprotocol, requestTimeout, maxResponseSize) + ret0, _ := ret[0].(network.RequestMaker) return ret0 } // GetRequestResponseProtocol indicates an expected call of GetRequestResponseProtocol. -func (mr *MockNetworkMockRecorder) GetRequestResponseProtocol(arg0, arg1, arg2 any) *gomock.Call { +func (mr *MockNetworkMockRecorder) GetRequestResponseProtocol(subprotocol, requestTimeout, maxResponseSize any) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetRequestResponseProtocol", reflect.TypeOf((*MockNetwork)(nil).GetRequestResponseProtocol), arg0, arg1, arg2) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetRequestResponseProtocol", reflect.TypeOf((*MockNetwork)(nil).GetRequestResponseProtocol), subprotocol, requestTimeout, maxResponseSize) } // GossipMessage mocks base method. -func (m *MockNetwork) GossipMessage(arg0 network.NotificationsMessage) { +func (m *MockNetwork) GossipMessage(msg network.NotificationsMessage) { m.ctrl.T.Helper() - m.ctrl.Call(m, "GossipMessage", arg0) + m.ctrl.Call(m, "GossipMessage", msg) } // GossipMessage indicates an expected call of GossipMessage. -func (mr *MockNetworkMockRecorder) GossipMessage(arg0 any) *gomock.Call { +func (mr *MockNetworkMockRecorder) GossipMessage(msg any) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GossipMessage", reflect.TypeOf((*MockNetwork)(nil).GossipMessage), arg0) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GossipMessage", reflect.TypeOf((*MockNetwork)(nil).GossipMessage), msg) } // RegisterNotificationsProtocol mocks base method. -func (m *MockNetwork) RegisterNotificationsProtocol(arg0 protocol.ID, arg1 network.MessageType, arg2 func() (network.Handshake, error), arg3 func([]byte) (network.Handshake, error), arg4 func(peer.ID, network.Handshake) error, arg5 func([]byte) (network.NotificationsMessage, error), arg6 func(peer.ID, network.NotificationsMessage) (bool, error), arg7 func(peer.ID, network.NotificationsMessage), arg8 uint64) error { +func (m *MockNetwork) RegisterNotificationsProtocol(sub protocol.ID, messageID network.MessageType, handshakeGetter func() (network.Handshake, error), handshakeDecoder func([]byte) (network.Handshake, error), handshakeValidator func(peer.ID, network.Handshake) error, messageDecoder func([]byte) (network.NotificationsMessage, error), messageHandler func(peer.ID, network.NotificationsMessage) (bool, error), batchHandler func(peer.ID, network.NotificationsMessage), maxSize uint64) error { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "RegisterNotificationsProtocol", arg0, arg1, arg2, arg3, arg4, arg5, arg6, arg7, arg8) + ret := m.ctrl.Call(m, "RegisterNotificationsProtocol", sub, messageID, handshakeGetter, handshakeDecoder, handshakeValidator, messageDecoder, messageHandler, batchHandler, maxSize) ret0, _ := ret[0].(error) return ret0 } // RegisterNotificationsProtocol indicates an expected call of RegisterNotificationsProtocol. -func (mr *MockNetworkMockRecorder) RegisterNotificationsProtocol(arg0, arg1, arg2, arg3, arg4, arg5, arg6, arg7, arg8 any) *gomock.Call { +func (mr *MockNetworkMockRecorder) RegisterNotificationsProtocol(sub, messageID, handshakeGetter, handshakeDecoder, handshakeValidator, messageDecoder, messageHandler, batchHandler, maxSize any) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RegisterNotificationsProtocol", reflect.TypeOf((*MockNetwork)(nil).RegisterNotificationsProtocol), arg0, arg1, arg2, arg3, arg4, arg5, arg6, arg7, arg8) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RegisterNotificationsProtocol", reflect.TypeOf((*MockNetwork)(nil).RegisterNotificationsProtocol), sub, messageID, handshakeGetter, handshakeDecoder, handshakeValidator, messageDecoder, messageHandler, batchHandler, maxSize) } diff --git a/dot/parachain/collator-protocol/validator_side.go b/dot/parachain/collator-protocol/validator_side.go index ddd5b576e6..ff6a972db0 100644 --- a/dot/parachain/collator-protocol/validator_side.go +++ b/dot/parachain/collator-protocol/validator_side.go @@ -559,7 +559,7 @@ type Network interface { maxSize uint64, ) error GetRequestResponseProtocol(subprotocol string, requestTimeout time.Duration, - maxResponseSize uint64) *network.RequestResponseProtocol + maxResponseSize uint64) network.RequestMaker } type CollationEvent struct { @@ -574,7 +574,7 @@ type CollatorProtocolValidatorSide struct { SubSystemToOverseer chan<- any unfetchedCollation chan UnfetchedCollation - collationFetchingReqResProtocol *network.RequestResponseProtocol + collationFetchingReqResProtocol network.RequestMaker fetchedCollations []parachaintypes.Collation // track all active collators and their data diff --git a/dot/parachain/network-bridge/interface.go b/dot/parachain/network-bridge/interface.go index d070a3c592..baa7d6b17d 100644 --- a/dot/parachain/network-bridge/interface.go +++ b/dot/parachain/network-bridge/interface.go @@ -27,7 +27,7 @@ type Network interface { maxSize uint64, ) error GetRequestResponseProtocol(subprotocol string, requestTimeout time.Duration, - maxResponseSize uint64) *network.RequestResponseProtocol + maxResponseSize uint64) network.RequestMaker ReportPeer(change peerset.ReputationChange, p peer.ID) DisconnectPeer(setID int, p peer.ID) GetNetworkEventsChannel() chan *network.NetworkEventInfo diff --git a/dot/parachain/chunk_fetching.go b/dot/parachain/network-bridge/messages/chunk_fetching.go similarity index 87% rename from dot/parachain/chunk_fetching.go rename to dot/parachain/network-bridge/messages/chunk_fetching.go index cd3086da39..aaa56f3185 100644 --- a/dot/parachain/chunk_fetching.go +++ b/dot/parachain/network-bridge/messages/chunk_fetching.go @@ -1,11 +1,13 @@ // Copyright 2023 ChainSafe Systems (ON) // SPDX-License-Identifier: LGPL-3.0-only -package parachain +package messages import ( "fmt" + "github.com/ChainSafe/gossamer/dot/network" + parachaintypes "github.com/ChainSafe/gossamer/dot/parachain/types" "github.com/ChainSafe/gossamer/pkg/scale" ) @@ -24,6 +26,16 @@ func (c ChunkFetchingRequest) Encode() ([]byte, error) { return scale.Marshal(c) } +// Protocol returns the sub-protocol ID for this message +func (c ChunkFetchingRequest) Protocol() ReqProtocolName { + return ChunkFetchingV1 +} + +// Response returns an instance of the response type for this message, for the purpose of decoding into it. +func (c ChunkFetchingRequest) Response() network.ResponseMessage { + return &ChunkFetchingResponse{} +} + type ChunkFetchingResponseValues interface { ChunkResponse | NoSuchChunk } diff --git a/dot/parachain/chunk_fetching_test.go b/dot/parachain/network-bridge/messages/chunk_fetching_test.go similarity index 99% rename from dot/parachain/chunk_fetching_test.go rename to dot/parachain/network-bridge/messages/chunk_fetching_test.go index a44d1962f4..5aa1b0ac77 100644 --- a/dot/parachain/chunk_fetching_test.go +++ b/dot/parachain/network-bridge/messages/chunk_fetching_test.go @@ -1,7 +1,7 @@ // Copyright 2023 ChainSafe Systems (ON) // SPDX-License-Identifier: LGPL-3.0-only -package parachain +package messages import ( "testing" diff --git a/dot/parachain/network-bridge/messages/request_response_protocols.go b/dot/parachain/network-bridge/messages/request_response_protocols.go new file mode 100644 index 0000000000..820ec29fae --- /dev/null +++ b/dot/parachain/network-bridge/messages/request_response_protocols.go @@ -0,0 +1,95 @@ +// Copyright 2025 ChainSafe Systems (ON) +// SPDX-License-Identifier: LGPL-3.0-only + +package messages + +import ( + "context" + + "github.com/ChainSafe/gossamer/dot/network" + "github.com/libp2p/go-libp2p/core/peer" +) + +type ReqProtocolName uint + +const ( + ChunkFetchingV1 ReqProtocolName = iota + CollationFetchingV1 + PoVFetchingV1 + AvailableDataFetchingV1 + StatementFetchingV1 + DisputeSendingV1 +) + +func (n ReqProtocolName) String() string { + switch n { + case ChunkFetchingV1: + return "req_chunk/1" + case CollationFetchingV1: + return "req_collation/1" + case PoVFetchingV1: + return "req_pov/1" + case AvailableDataFetchingV1: + return "req_available_data/1" + case StatementFetchingV1: + return "req_statement/1" + case DisputeSendingV1: + return "send_dispute/1" + default: + panic("unknown protocol") + } +} + +// ReqProtocolMessage is a network message that can be sent over a request response protocol. +type ReqProtocolMessage interface { + network.Message + // Response returns an instance of the response type for this message, for the purpose of decoding into it. + Response() network.ResponseMessage + Protocol() ReqProtocolName +} + +// ReqRespResult is the result of sending a request over a request response protocol. It contains either a response +// message or an error. +type ReqRespResult struct { + Response network.ResponseMessage + Error error +} + +// OutgoingRequest contains all data required to send a request over a request response protocol and receive the result. +type OutgoingRequest struct { + Recipient peer.ID // TODO use a type that can contain either a peer ID or an authority ID + Payload ReqProtocolMessage + Result chan ReqRespResult + + ctx context.Context + cancel context.CancelFunc +} + +// Done returns a channel that is closed when the request is cancelled. +func (or *OutgoingRequest) Done() <-chan struct{} { + return or.ctx.Done() +} + +// Cancel cancels the request. +func (or *OutgoingRequest) Cancel() { + or.cancel() +} + +// IsCancelled returns true if the request has been cancelled. +func (or *OutgoingRequest) IsCancelled() bool { + return or.ctx.Err() != nil +} + +// NewOutgoingRequest creates a new outgoing request. +func NewOutgoingRequest(recipient peer.ID, payload ReqProtocolMessage) *OutgoingRequest { + result := make(chan ReqRespResult, 1) + ctx, cancel := context.WithCancel(context.Background()) + + return &OutgoingRequest{ + Recipient: recipient, + Payload: payload, + Result: result, + ctx: ctx, + cancel: cancel, + } +} diff --git a/dot/parachain/network-bridge/messages/tx_overseer_messages.go b/dot/parachain/network-bridge/messages/tx_overseer_messages.go index e7d454689e..8c41e3613c 100644 --- a/dot/parachain/network-bridge/messages/tx_overseer_messages.go +++ b/dot/parachain/network-bridge/messages/tx_overseer_messages.go @@ -63,3 +63,16 @@ type ConnectToValidators struct { // authority discovery has Failed to resolve. Failed chan<- uint } + +type IfDisconnectedBehavior int + +const ( + TryConnect IfDisconnectedBehavior = iota + ImmediateError // TODO not implemented +) + +// SendRequests is a subsystem message for sending requests over a request response protocol. +type SendRequests struct { + Requests []*OutgoingRequest + IfDisconnected IfDisconnectedBehavior +} diff --git a/dot/parachain/network-bridge/mock_request_maker_test.go b/dot/parachain/network-bridge/mock_request_maker_test.go new file mode 100644 index 0000000000..166f87e5cf --- /dev/null +++ b/dot/parachain/network-bridge/mock_request_maker_test.go @@ -0,0 +1,56 @@ +// Code generated by MockGen. DO NOT EDIT. +// Source: github.com/ChainSafe/gossamer/dot/network (interfaces: RequestMaker) +// +// Generated by this command: +// +// mockgen -destination=mock_request_maker_test.go -package=networkbridge github.com/ChainSafe/gossamer/dot/network RequestMaker +// + +// Package networkbridge is a generated GoMock package. +package networkbridge + +import ( + reflect "reflect" + + network "github.com/ChainSafe/gossamer/dot/network" + peer "github.com/libp2p/go-libp2p/core/peer" + gomock "go.uber.org/mock/gomock" +) + +// MockRequestMaker is a mock of RequestMaker interface. +type MockRequestMaker struct { + ctrl *gomock.Controller + recorder *MockRequestMakerMockRecorder + isgomock struct{} +} + +// MockRequestMakerMockRecorder is the mock recorder for MockRequestMaker. +type MockRequestMakerMockRecorder struct { + mock *MockRequestMaker +} + +// NewMockRequestMaker creates a new mock instance. +func NewMockRequestMaker(ctrl *gomock.Controller) *MockRequestMaker { + mock := &MockRequestMaker{ctrl: ctrl} + mock.recorder = &MockRequestMakerMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use. +func (m *MockRequestMaker) EXPECT() *MockRequestMakerMockRecorder { + return m.recorder +} + +// Do mocks base method. +func (m *MockRequestMaker) Do(to peer.ID, req network.Message, res network.ResponseMessage) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Do", to, req, res) + ret0, _ := ret[0].(error) + return ret0 +} + +// Do indicates an expected call of Do. +func (mr *MockRequestMakerMockRecorder) Do(to, req, res any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Do", reflect.TypeOf((*MockRequestMaker)(nil).Do), to, req, res) +} diff --git a/dot/parachain/network-bridge/mocks_generate_test.go b/dot/parachain/network-bridge/mocks_generate_test.go new file mode 100644 index 0000000000..9c1ac6f979 --- /dev/null +++ b/dot/parachain/network-bridge/mocks_generate_test.go @@ -0,0 +1,7 @@ +// Copyright 2025 ChainSafe Systems (ON) +// SPDX-License-Identifier: LGPL-3.0-only + +package networkbridge + +//go:generate mockgen -destination=mocks_test.go -package=$GOPACKAGE . Network +//go:generate mockgen -destination=mock_request_maker_test.go -package=$GOPACKAGE github.com/ChainSafe/gossamer/dot/network RequestMaker diff --git a/dot/parachain/network-bridge/mocks_test.go b/dot/parachain/network-bridge/mocks_test.go new file mode 100644 index 0000000000..ae19f2dcef --- /dev/null +++ b/dot/parachain/network-bridge/mocks_test.go @@ -0,0 +1,149 @@ +// Code generated by MockGen. DO NOT EDIT. +// Source: github.com/ChainSafe/gossamer/dot/parachain/network-bridge (interfaces: Network) +// +// Generated by this command: +// +// mockgen -destination=mocks_test.go -package=networkbridge . Network +// + +// Package networkbridge is a generated GoMock package. +package networkbridge + +import ( + reflect "reflect" + time "time" + + network "github.com/ChainSafe/gossamer/dot/network" + peerset "github.com/ChainSafe/gossamer/dot/peerset" + peer "github.com/libp2p/go-libp2p/core/peer" + protocol "github.com/libp2p/go-libp2p/core/protocol" + gomock "go.uber.org/mock/gomock" +) + +// MockNetwork is a mock of Network interface. +type MockNetwork struct { + ctrl *gomock.Controller + recorder *MockNetworkMockRecorder + isgomock struct{} +} + +// MockNetworkMockRecorder is the mock recorder for MockNetwork. +type MockNetworkMockRecorder struct { + mock *MockNetwork +} + +// NewMockNetwork creates a new mock instance. +func NewMockNetwork(ctrl *gomock.Controller) *MockNetwork { + mock := &MockNetwork{ctrl: ctrl} + mock.recorder = &MockNetworkMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use. +func (m *MockNetwork) EXPECT() *MockNetworkMockRecorder { + return m.recorder +} + +// DisconnectPeer mocks base method. +func (m *MockNetwork) DisconnectPeer(setID int, p peer.ID) { + m.ctrl.T.Helper() + m.ctrl.Call(m, "DisconnectPeer", setID, p) +} + +// DisconnectPeer indicates an expected call of DisconnectPeer. +func (mr *MockNetworkMockRecorder) DisconnectPeer(setID, p any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DisconnectPeer", reflect.TypeOf((*MockNetwork)(nil).DisconnectPeer), setID, p) +} + +// FreeNetworkEventsChannel mocks base method. +func (m *MockNetwork) FreeNetworkEventsChannel(ch chan *network.NetworkEventInfo) { + m.ctrl.T.Helper() + m.ctrl.Call(m, "FreeNetworkEventsChannel", ch) +} + +// FreeNetworkEventsChannel indicates an expected call of FreeNetworkEventsChannel. +func (mr *MockNetworkMockRecorder) FreeNetworkEventsChannel(ch any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "FreeNetworkEventsChannel", reflect.TypeOf((*MockNetwork)(nil).FreeNetworkEventsChannel), ch) +} + +// GetNetworkEventsChannel mocks base method. +func (m *MockNetwork) GetNetworkEventsChannel() chan *network.NetworkEventInfo { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetNetworkEventsChannel") + ret0, _ := ret[0].(chan *network.NetworkEventInfo) + return ret0 +} + +// GetNetworkEventsChannel indicates an expected call of GetNetworkEventsChannel. +func (mr *MockNetworkMockRecorder) GetNetworkEventsChannel() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetNetworkEventsChannel", reflect.TypeOf((*MockNetwork)(nil).GetNetworkEventsChannel)) +} + +// GetRequestResponseProtocol mocks base method. +func (m *MockNetwork) GetRequestResponseProtocol(subprotocol string, requestTimeout time.Duration, maxResponseSize uint64) network.RequestMaker { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetRequestResponseProtocol", subprotocol, requestTimeout, maxResponseSize) + ret0, _ := ret[0].(network.RequestMaker) + return ret0 +} + +// GetRequestResponseProtocol indicates an expected call of GetRequestResponseProtocol. +func (mr *MockNetworkMockRecorder) GetRequestResponseProtocol(subprotocol, requestTimeout, maxResponseSize any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetRequestResponseProtocol", reflect.TypeOf((*MockNetwork)(nil).GetRequestResponseProtocol), subprotocol, requestTimeout, maxResponseSize) +} + +// GossipMessage mocks base method. +func (m *MockNetwork) GossipMessage(msg network.NotificationsMessage) { + m.ctrl.T.Helper() + m.ctrl.Call(m, "GossipMessage", msg) +} + +// GossipMessage indicates an expected call of GossipMessage. +func (mr *MockNetworkMockRecorder) GossipMessage(msg any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GossipMessage", reflect.TypeOf((*MockNetwork)(nil).GossipMessage), msg) +} + +// RegisterNotificationsProtocol mocks base method. +func (m *MockNetwork) RegisterNotificationsProtocol(sub protocol.ID, messageID network.MessageType, handshakeGetter func() (network.Handshake, error), handshakeDecoder func([]byte) (network.Handshake, error), handshakeValidator func(peer.ID, network.Handshake) error, messageDecoder func([]byte) (network.NotificationsMessage, error), messageHandler func(peer.ID, network.NotificationsMessage) (bool, error), batchHandler func(peer.ID, network.NotificationsMessage), maxSize uint64) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "RegisterNotificationsProtocol", sub, messageID, handshakeGetter, handshakeDecoder, handshakeValidator, messageDecoder, messageHandler, batchHandler, maxSize) + ret0, _ := ret[0].(error) + return ret0 +} + +// RegisterNotificationsProtocol indicates an expected call of RegisterNotificationsProtocol. +func (mr *MockNetworkMockRecorder) RegisterNotificationsProtocol(sub, messageID, handshakeGetter, handshakeDecoder, handshakeValidator, messageDecoder, messageHandler, batchHandler, maxSize any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RegisterNotificationsProtocol", reflect.TypeOf((*MockNetwork)(nil).RegisterNotificationsProtocol), sub, messageID, handshakeGetter, handshakeDecoder, handshakeValidator, messageDecoder, messageHandler, batchHandler, maxSize) +} + +// ReportPeer mocks base method. +func (m *MockNetwork) ReportPeer(change peerset.ReputationChange, p peer.ID) { + m.ctrl.T.Helper() + m.ctrl.Call(m, "ReportPeer", change, p) +} + +// ReportPeer indicates an expected call of ReportPeer. +func (mr *MockNetworkMockRecorder) ReportPeer(change, p any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ReportPeer", reflect.TypeOf((*MockNetwork)(nil).ReportPeer), change, p) +} + +// SendMessage mocks base method. +func (m *MockNetwork) SendMessage(to peer.ID, msg network.NotificationsMessage) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "SendMessage", to, msg) + ret0, _ := ret[0].(error) + return ret0 +} + +// SendMessage indicates an expected call of SendMessage. +func (mr *MockNetworkMockRecorder) SendMessage(to, msg any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SendMessage", reflect.TypeOf((*MockNetwork)(nil).SendMessage), to, msg) +} diff --git a/dot/parachain/network-bridge/sender.go b/dot/parachain/network-bridge/sender.go index e1a2df06e8..c702dc371e 100644 --- a/dot/parachain/network-bridge/sender.go +++ b/dot/parachain/network-bridge/sender.go @@ -6,6 +6,7 @@ package networkbridge import ( "context" "fmt" + "time" "github.com/ChainSafe/gossamer/dot/network" networkbridgemessages "github.com/ChainSafe/gossamer/dot/parachain/network-bridge/messages" @@ -92,7 +93,9 @@ func (nbs *NetworkBridgeSender) processMessage(msg any) error { return fmt.Errorf("sending message: %w", err) } } - // TODO: add ConnectTOResolvedValidators, SendRequests + case networkbridgemessages.SendRequests: + nbs.sendRequests(msg.Requests, msg.IfDisconnected) + // TODO: add ConnectTOResolvedValidators case networkbridgemessages.ConnectToValidators: // TODO case networkbridgemessages.ReportPeer: @@ -104,3 +107,48 @@ func (nbs *NetworkBridgeSender) processMessage(msg any) error { return nil } + +const requestTimeout = 2 * time.Second + +// PoV is probably the largest message and is currently set at 5MB, but will likely be increased to 10MB in the future. +// see: https://github.com/paritytech/polkadot-sdk/issues/5334 +// Maybe message types should have a MaxSize() method instead of using the same value for all messages. +const maxResponseSize uint64 = 5 * 1024 * 1024 + +func (nbs *NetworkBridgeSender) sendRequests( + requests []*networkbridgemessages.OutgoingRequest, + ifDisconnected networkbridgemessages.IfDisconnectedBehavior, //nolint:unparam +) { + for _, request := range requests { + if request.IsCancelled() { + close(request.Result) + continue + } + + result := nbs.sendRequest(request) + + if !request.IsCancelled() { + request.Result <- result + request.Cancel() // only called here to avoid resource leaks + } + close(request.Result) + } +} + +func (nbs *NetworkBridgeSender) sendRequest( + request *networkbridgemessages.OutgoingRequest, +) networkbridgemessages.ReqRespResult { + protoID := request.Payload.Protocol().String() + protocol := nbs.net.GetRequestResponseProtocol(protoID, requestTimeout, maxResponseSize) + response := request.Payload.Response() + result := networkbridgemessages.ReqRespResult{} + + err := protocol.Do(request.Recipient, request.Payload, response) + if err != nil { + result.Error = err + } else { + result.Response = response + } + + return result +} diff --git a/dot/parachain/network-bridge/sender_test.go b/dot/parachain/network-bridge/sender_test.go new file mode 100644 index 0000000000..b5d98e0f45 --- /dev/null +++ b/dot/parachain/network-bridge/sender_test.go @@ -0,0 +1,203 @@ +// Copyright 2025 ChainSafe Systems (ON) +// SPDX-License-Identifier: LGPL-3.0-only + +package networkbridge + +import ( + "errors" + "testing" + + "github.com/ChainSafe/gossamer/dot/network" + networkbridgemessages "github.com/ChainSafe/gossamer/dot/parachain/network-bridge/messages" + parachaintypes "github.com/ChainSafe/gossamer/dot/parachain/types" + "github.com/ChainSafe/gossamer/lib/common" + "github.com/libp2p/go-libp2p/core/peer" + "github.com/stretchr/testify/require" + "go.uber.org/mock/gomock" +) + +func TestSendRequests(t *testing.T) { + t.Parallel() + + t.Run("request_succeeds", func(t *testing.T) { + t.Parallel() + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + request := makeOutgoingRequest(t) + response := &networkbridgemessages.ChunkFetchingResponse{} + expectedValue := networkbridgemessages.NoSuchChunk{} + require.NoError(t, response.SetValue(expectedValue)) + + nbs := setUpNetworkBridgeSender(t, ctrl, request, response, nil, nil) + + sendRequests := networkbridgemessages.SendRequests{ + Requests: []*networkbridgemessages.OutgoingRequest{request}, + IfDisconnected: networkbridgemessages.TryConnect, + } + + err := nbs.processMessage(sendRequests) + require.NoError(t, err) + + result := <-request.Result + require.NoError(t, result.Error) + + cfResponse, ok := result.Response.(*networkbridgemessages.ChunkFetchingResponse) + require.True(t, ok) + + actualValue, err := cfResponse.Value() + require.NoError(t, err) + + require.Equal(t, expectedValue, actualValue) + requireEmptyAndClosed(t, request.Result) + }) + + t.Run("request_fails", func(t *testing.T) { + t.Parallel() + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + reqErr := errors.New("timeout") + request := makeOutgoingRequest(t) + + nbs := setUpNetworkBridgeSender(t, ctrl, request, nil, nil, reqErr) + + sendRequests := networkbridgemessages.SendRequests{ + Requests: []*networkbridgemessages.OutgoingRequest{request}, + IfDisconnected: networkbridgemessages.TryConnect, + } + + err := nbs.processMessage(sendRequests) + require.NoError(t, err) + + result := <-request.Result + require.Equal(t, reqErr, result.Error) + require.Nil(t, result.Response) + requireEmptyAndClosed(t, request.Result) + }) + + t.Run("decoding_fails", func(t *testing.T) { + t.Parallel() + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + request := makeOutgoingRequest(t) + rawResponse := []byte("an invalid network response message") + + nbs := setUpNetworkBridgeSender(t, ctrl, request, nil, rawResponse, nil) + + sendRequests := networkbridgemessages.SendRequests{ + Requests: []*networkbridgemessages.OutgoingRequest{request}, + IfDisconnected: networkbridgemessages.TryConnect, + } + + err := nbs.processMessage(sendRequests) + require.NoError(t, err) + + result := <-request.Result + require.Error(t, result.Error) + require.Nil(t, result.Response) + requireEmptyAndClosed(t, request.Result) + }) + + t.Run("cancel_request", func(t *testing.T) { + t.Parallel() + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + request := makeOutgoingRequest(t) + nbs := setUpNetworkBridgeSender(t, ctrl, request, nil, nil, nil) + + sendRequests := networkbridgemessages.SendRequests{ + Requests: []*networkbridgemessages.OutgoingRequest{request}, + IfDisconnected: networkbridgemessages.TryConnect, + } + + request.Cancel() + + err := nbs.processMessage(sendRequests) + require.NoError(t, err) + + result := <-request.Result + require.Nil(t, result.Response) + require.NoError(t, result.Error) + requireEmptyAndClosed(t, request.Result) + }) +} + +// We arbitrarily use a ChunkFetchingRequest since it does not matter for testing SendRequests handling. +func makeOutgoingRequest(t *testing.T) *networkbridgemessages.OutgoingRequest { + t.Helper() + + return networkbridgemessages.NewOutgoingRequest( + "recipient", + networkbridgemessages.ChunkFetchingRequest{ + CandidateHash: parachaintypes.CandidateHash{Value: common.Hash{1}}, + Index: 42, + }) +} + +// Expect calls to Network.GetRequestResponseProtocol() and RequestMaker.Do() when only one of response, rawResponse or +// reqErr is non-nil. +// Expect no calls Network.GetRequestResponseProtocol() RequestMaker.Do() when all three are nil. +func setUpNetworkBridgeSender( + t *testing.T, + ctrl *gomock.Controller, + request *networkbridgemessages.OutgoingRequest, + response network.ResponseMessage, + rawResponse []byte, + reqErr error, +) *NetworkBridgeSender { + t.Helper() + + expectCancellation := response == nil && rawResponse == nil && reqErr == nil + reqMaker := NewMockRequestMaker(ctrl) + + if expectCancellation { + reqMaker.EXPECT(). + Do(gomock.Any(), gomock.Any(), gomock.Any()). + Times(0) + } else { + reqMaker.EXPECT(). + Do(request.Recipient, request.Payload, gomock.AssignableToTypeOf(request.Payload.Response())). + DoAndReturn(func(to peer.ID, req network.Message, res network.ResponseMessage) error { + if reqErr != nil { + return reqErr + } + + if response != nil { + var err error + rawResponse, err = response.Encode() + require.NoError(t, err) + } + + if err := res.Decode(rawResponse); err != nil { + return err + } + return nil + }) + } + + netService := NewMockNetwork(ctrl) + + if expectCancellation { + netService.EXPECT(). + GetRequestResponseProtocol(gomock.Any(), gomock.Any(), gomock.Any()). + Times(0) + } else { + netService.EXPECT(). + GetRequestResponseProtocol(request.Payload.Protocol().String(), gomock.Any(), gomock.Any()). + Return(reqMaker) + } + + return RegisterSender(nil, netService) +} + +func requireEmptyAndClosed(t *testing.T, ch chan networkbridgemessages.ReqRespResult) { + select { + case _, ok := <-ch: + require.False(t, ok, "channel was not empty") + default: + t.Error("channel is not closed") + } +} diff --git a/dot/parachain/network_protocols.go b/dot/parachain/network_protocols.go index 880febfbed..ee7434e295 100644 --- a/dot/parachain/network_protocols.go +++ b/dot/parachain/network_protocols.go @@ -10,17 +10,6 @@ import ( "github.com/ChainSafe/gossamer/lib/common" ) -type ReqProtocolName uint - -const ( - ChunkFetchingV1 ReqProtocolName = iota - CollationFetchingV1 - PoVFetchingV1 - AvailableDataFetchingV1 - StatementFetchingV1 - DisputeSendingV1 -) - type PeerSetProtocolName uint const ( @@ -28,31 +17,6 @@ const ( CollationProtocolName ) -func GenerateReqProtocolName(protocol ReqProtocolName, forkID string, GenesisHash common.Hash) string { - prefix := fmt.Sprintf("/%s", GenesisHash.String()) - - if forkID != "" { - prefix = fmt.Sprintf("%s/%s", prefix, forkID) - } - - switch protocol { - case ChunkFetchingV1: - return fmt.Sprintf("%s/req_chunk/1", prefix) - case CollationFetchingV1: - return fmt.Sprintf("%s/req_collation/1", prefix) - case PoVFetchingV1: - return fmt.Sprintf("%s/req_pov/1", prefix) - case AvailableDataFetchingV1: - return fmt.Sprintf("%s/req_available_data/1", prefix) - case StatementFetchingV1: - return fmt.Sprintf("%s/req_statement/1", prefix) - case DisputeSendingV1: - return fmt.Sprintf("%s/send_dispute/1", prefix) - default: - panic("unknown protocol") - } -} - func GeneratePeersetProtocolName(protocol PeerSetProtocolName, forkID string, GenesisHash common.Hash, version uint32, ) string { genesisHash := GenesisHash.String() diff --git a/dot/parachain/overseer/overseer.go b/dot/parachain/overseer/overseer.go index 626143c8e3..d0a8aea88e 100644 --- a/dot/parachain/overseer/overseer.go +++ b/dot/parachain/overseer/overseer.go @@ -135,7 +135,7 @@ func (o *OverseerSystem) processMessages() { case networkbridgemessages.DisconnectPeer, networkbridgemessages.ConnectToValidators, networkbridgemessages.ReportPeer, networkbridgemessages.SendCollationMessage, - networkbridgemessages.SendValidationMessage: + networkbridgemessages.SendValidationMessage, networkbridgemessages.SendRequests: subsystem = o.nameToSubsystem[parachaintypes.NetworkBridgeSender] case networkbridgeevents.Event[collatorprotocolmessages.CollationProtocol]: diff --git a/dot/parachain/prospective-parachains/fragment_chain_test.go b/dot/parachain/prospective-parachains/fragment_chain_test.go index 7951565557..d08e44667a 100644 --- a/dot/parachain/prospective-parachains/fragment_chain_test.go +++ b/dot/parachain/prospective-parachains/fragment_chain_test.go @@ -916,7 +916,7 @@ func TestCandidateStorageMethods(t *testing.T) { possibleBackedCandidateHashes = append(possibleBackedCandidateHashes, entry.candidateHash) } - require.Equal(t, []parachaintypes.CandidateHash{candidateHash}, possibleBackedCandidateHashes) + require.Contains(t, possibleBackedCandidateHashes, candidateHash) // now mark it as backed storage.markBacked(candidateHash2) @@ -928,9 +928,10 @@ func TestCandidateStorageMethods(t *testing.T) { possibleBackedCandidateHashes = append(possibleBackedCandidateHashes, entry.candidateHash) } - require.Equal(t, []parachaintypes.CandidateHash{ - candidateHash, candidateHash2}, possibleBackedCandidateHashes) - + // The iterator returned by storage.possibleBackedParaChildren() takes values from a map. + // Therefore we must not assert on the order of elements in possibleBackedCandidateHashes. + require.Contains(t, possibleBackedCandidateHashes, candidateHash) + require.Contains(t, possibleBackedCandidateHashes, candidateHash2) }) }, }, diff --git a/dot/parachain/prospective-parachains/prospective_parachains_test.go b/dot/parachain/prospective-parachains/prospective_parachains_test.go index 806d6dfd17..89137d776e 100644 --- a/dot/parachain/prospective-parachains/prospective_parachains_test.go +++ b/dot/parachain/prospective-parachains/prospective_parachains_test.go @@ -361,10 +361,14 @@ func TestGetMinimumRelayParents(t *testing.T) { BlockNumber: 10, }, } - // Validate the results + result := <-sender - assert.Len(t, result, 2) - assert.Equal(t, expected, result) + assert.Len(t, result, len(expected)) + + // Validate the results without asserting on the order of ParaIDBlockNumber values. + for _, ex := range expected { + assert.Contains(t, result, ex) + } } // TestGetMinimumRelayParents_NoActiveLeaves ensures that getMinimumRelayParents diff --git a/dot/parachain/service.go b/dot/parachain/service.go index 2e86939d2b..14ec9885c7 100644 --- a/dot/parachain/service.go +++ b/dot/parachain/service.go @@ -168,7 +168,7 @@ type Network interface { maxSize uint64, ) error GetRequestResponseProtocol(subprotocol string, requestTimeout time.Duration, - maxResponseSize uint64) *network.RequestResponseProtocol + maxResponseSize uint64) network.RequestMaker ReportPeer(change peerset.ReputationChange, p peer.ID) DisconnectPeer(setID int, p peer.ID) GetNetworkEventsChannel() chan *network.NetworkEventInfo