From cafef2650aa0a81acb3d0bfd0828800586faf47a Mon Sep 17 00:00:00 2001 From: Kittywhiskers Van Gogh <63189531+kwvg@users.noreply.github.com> Date: Fri, 9 Aug 2024 19:37:21 +0000 Subject: [PATCH] merge bitcoin#28165: transport abstraction --- src/init.cpp | 2 +- src/net.cpp | 209 ++++++++++---- src/net.h | 203 +++++++++----- src/test/denialofservice_tests.cpp | 7 +- src/test/fuzz/p2p_transport_serialization.cpp | 262 +++++++++++++++++- src/test/fuzz/process_messages.cpp | 3 +- src/test/util/net.cpp | 33 ++- src/test/util/net.h | 3 +- 8 files changed, 578 insertions(+), 144 deletions(-) diff --git a/src/init.cpp b/src/init.cpp index 84ed6d4a759271..69c9b5e7924e03 100644 --- a/src/init.cpp +++ b/src/init.cpp @@ -584,7 +584,7 @@ void SetupServerArgs(NodeContext& node) argsman.AddArg("-listenonion", strprintf("Automatically create Tor onion service (default: %d)", DEFAULT_LISTEN_ONION), ArgsManager::ALLOW_ANY, OptionsCategory::CONNECTION); argsman.AddArg("-maxconnections=", strprintf("Maintain at most connections to peers (temporary service connections excluded) (default: %u). This limit does not apply to connections manually added via -addnode or the addnode RPC, which have a separate limit of %u.", DEFAULT_MAX_PEER_CONNECTIONS, MAX_ADDNODE_CONNECTIONS), ArgsManager::ALLOW_ANY, OptionsCategory::CONNECTION); argsman.AddArg("-maxreceivebuffer=", strprintf("Maximum per-connection receive buffer, *1000 bytes (default: %u)", DEFAULT_MAXRECEIVEBUFFER), ArgsManager::ALLOW_ANY, OptionsCategory::CONNECTION); - argsman.AddArg("-maxsendbuffer=", strprintf("Maximum per-connection send buffer, *1000 bytes (default: %u)", DEFAULT_MAXSENDBUFFER), ArgsManager::ALLOW_ANY, OptionsCategory::CONNECTION); + argsman.AddArg("-maxsendbuffer=", strprintf("Maximum per-connection memory usage for the send buffer, *1000 bytes (default: %u)", DEFAULT_MAXSENDBUFFER), ArgsManager::ALLOW_ANY, OptionsCategory::CONNECTION); argsman.AddArg("-maxtimeadjustment", strprintf("Maximum allowed median peer time offset adjustment. Local perspective of time may be influenced by peers forward or backward by this amount. (default: %u seconds)", DEFAULT_MAX_TIME_ADJUSTMENT), ArgsManager::ALLOW_ANY, OptionsCategory::CONNECTION); argsman.AddArg("-maxuploadtarget=", strprintf("Tries to keep outbound traffic under the given target (in MiB per 24h). Limit does not apply to peers with 'download' permission. 0 = no limit (default: %d)", DEFAULT_MAX_UPLOAD_TARGET), ArgsManager::ALLOW_ANY, OptionsCategory::CONNECTION); argsman.AddArg("-onion=", "Use separate SOCKS5 proxy to reach peers via Tor onion services, set -noonion to disable (default: -proxy)", ArgsManager::ALLOW_ANY, OptionsCategory::CONNECTION); diff --git a/src/net.cpp b/src/net.cpp index 0af067cda135ac..3c490635361506 100644 --- a/src/net.cpp +++ b/src/net.cpp @@ -19,6 +19,7 @@ #include #include #include +#include #include #include #include @@ -142,6 +143,14 @@ std::map mapLocalHost GUARDED_BY(g_maplocalhost_mute static bool vfLimited[NET_MAX] GUARDED_BY(g_maplocalhost_mutex) = {}; std::string strSubVersion; +size_t CSerializedNetMsg::GetMemoryUsage() const noexcept +{ + // Don't count the dynamic memory used for the m_type string, by assuming it fits in the + // "small string" optimization area (which stores data inside the object itself, up to some + // size; 15 bytes in modern libstdc++). + return sizeof(*this) + memusage::DynamicUsage(data); +} + void CConnman::AddAddrFetch(const std::string& strDest) { LOCK(m_addr_fetches_mutex); @@ -787,16 +796,15 @@ bool CNode::ReceiveMsgBytes(Span msg_bytes, bool& complete) nRecvBytes += msg_bytes.size(); while (msg_bytes.size() > 0) { // absorb network data - int handled = m_deserializer->Read(msg_bytes); - if (handled < 0) { - // Serious header problem, disconnect from the peer. + if (!m_transport->ReceivedBytes(msg_bytes)) { + // Serious transport problem, disconnect from the peer. return false; } - if (m_deserializer->Complete()) { + if (m_transport->ReceivedMessageComplete()) { // decompose a transport agnostic CNetMessage from the deserializer bool reject_message{false}; - CNetMessage msg = m_deserializer->GetMessage(time, reject_message); + CNetMessage msg = m_transport->GetReceivedMessage(time, reject_message); if (reject_message) { // Message deserialization failed. Drop the message but don't disconnect the peer. // store the size of the corrupt message @@ -824,8 +832,18 @@ bool CNode::ReceiveMsgBytes(Span msg_bytes, bool& complete) return true; } -int V1TransportDeserializer::readHeader(Span msg_bytes) +V1Transport::V1Transport(const NodeId node_id, int nTypeIn, int nVersionIn) noexcept : + m_node_id(node_id), hdrbuf(nTypeIn, nVersionIn), vRecv(nTypeIn, nVersionIn) { + assert(std::size(Params().MessageStart()) == std::size(m_magic_bytes)); + std::copy(std::begin(Params().MessageStart()), std::end(Params().MessageStart()), m_magic_bytes); + LOCK(m_recv_mutex); + Reset(); +} + +int V1Transport::readHeader(Span msg_bytes) +{ + AssertLockHeld(m_recv_mutex); // copy data to temporary parsing buffer unsigned int nRemaining = CMessageHeader::HEADER_SIZE - nHdrPos; unsigned int nCopy = std::min(nRemaining, msg_bytes.size()); @@ -847,7 +865,7 @@ int V1TransportDeserializer::readHeader(Span msg_bytes) } // Check start string, network magic - if (memcmp(hdr.pchMessageStart, m_chain_params.MessageStart(), CMessageHeader::MESSAGE_START_SIZE) != 0) { + if (memcmp(hdr.pchMessageStart, m_magic_bytes, CMessageHeader::MESSAGE_START_SIZE) != 0) { LogPrint(BCLog::NET, "Header error: Wrong MessageStart %s received, peer=%d\n", HexStr(hdr.pchMessageStart), m_node_id); return -1; } @@ -864,8 +882,9 @@ int V1TransportDeserializer::readHeader(Span msg_bytes) return nCopy; } -int V1TransportDeserializer::readData(Span msg_bytes) +int V1Transport::readData(Span msg_bytes) { + AssertLockHeld(m_recv_mutex); unsigned int nRemaining = hdr.nMessageSize - nDataPos; unsigned int nCopy = std::min(nRemaining, msg_bytes.size()); @@ -881,19 +900,22 @@ int V1TransportDeserializer::readData(Span msg_bytes) return nCopy; } -const uint256& V1TransportDeserializer::GetMessageHash() const +const uint256& V1Transport::GetMessageHash() const { - assert(Complete()); + AssertLockHeld(m_recv_mutex); + assert(CompleteInternal()); if (data_hash.IsNull()) hasher.Finalize(data_hash); return data_hash; } -CNetMessage V1TransportDeserializer::GetMessage(const std::chrono::microseconds time, bool& reject_message) +CNetMessage V1Transport::GetReceivedMessage(const std::chrono::microseconds time, bool& reject_message) { + AssertLockNotHeld(m_recv_mutex); // Initialize out parameter reject_message = false; // decompose a single CNetMessage from the TransportDeserializer + LOCK(m_recv_mutex); CNetMessage msg(std::move(vRecv)); // store message type string, time, and sizes @@ -926,53 +948,122 @@ CNetMessage V1TransportDeserializer::GetMessage(const std::chrono::microseconds return msg; } -void V1TransportSerializer::prepareForTransport(CSerializedNetMsg& msg, std::vector& header) const +bool V1Transport::SetMessageToSend(CSerializedNetMsg& msg) noexcept { + AssertLockNotHeld(m_send_mutex); + // Determine whether a new message can be set. + LOCK(m_send_mutex); + if (m_sending_header || m_bytes_sent < m_message_to_send.data.size()) return false; + // create dbl-sha256 checksum uint256 hash = Hash(msg.data); // create header - CMessageHeader hdr(Params().MessageStart(), msg.m_type.c_str(), msg.data.size()); + CMessageHeader hdr(m_magic_bytes, msg.m_type.c_str(), msg.data.size()); memcpy(hdr.pchChecksum, hash.begin(), CMessageHeader::CHECKSUM_SIZE); // serialize header - header.reserve(CMessageHeader::HEADER_SIZE); - CVectorWriter{SER_NETWORK, INIT_PROTO_VERSION, header, 0, hdr}; + m_header_to_send.clear(); + CVectorWriter{SER_NETWORK, INIT_PROTO_VERSION, m_header_to_send, 0, hdr}; + + // update state + m_message_to_send = std::move(msg); + m_sending_header = true; + m_bytes_sent = 0; + return true; +} + +Transport::BytesToSend V1Transport::GetBytesToSend() const noexcept +{ + AssertLockNotHeld(m_send_mutex); + LOCK(m_send_mutex); + if (m_sending_header) { + return {Span{m_header_to_send}.subspan(m_bytes_sent), + // We have more to send after the header if the message has payload. + !m_message_to_send.data.empty(), + m_message_to_send.m_type + }; + } else { + return {Span{m_message_to_send.data}.subspan(m_bytes_sent), + // We never have more to send after this message's payload. + false, + m_message_to_send.m_type + }; + } +} + +void V1Transport::MarkBytesSent(size_t bytes_sent) noexcept +{ + AssertLockNotHeld(m_send_mutex); + LOCK(m_send_mutex); + m_bytes_sent += bytes_sent; + if (m_sending_header && m_bytes_sent == m_header_to_send.size()) { + // We're done sending a message's header. Switch to sending its data bytes. + m_sending_header = false; + m_bytes_sent = 0; + } else if (!m_sending_header && m_bytes_sent == m_message_to_send.data.size()) { + // We're done sending a message's data. Wipe the data vector to reduce memory consumption. + m_message_to_send.data.clear(); + m_message_to_send.data.shrink_to_fit(); + m_bytes_sent = 0; + } +} + +size_t V1Transport::GetSendMemoryUsage() const noexcept +{ + AssertLockNotHeld(m_send_mutex); + LOCK(m_send_mutex); + // Don't count sending-side fields besides m_message_to_send, as they're all small and bounded. + return m_message_to_send.GetMemoryUsage(); } std::pair CConnman::SocketSendData(CNode& node) const { auto it = node.vSendMsg.begin(); size_t nSentSize = 0; - - while (it != node.vSendMsg.end()) { - const auto& data = *it; - assert(data.size() > node.nSendOffset); + bool data_left{false}; //!< second return value (whether unsent data remains) + + while (true) { + if (it != node.vSendMsg.end()) { + // If possible, move one message from the send queue to the transport. This fails when + // there is an existing message still being sent. + size_t memusage = it->GetMemoryUsage(); + if (node.m_transport->SetMessageToSend(*it)) { + // Update memory usage of send buffer (as *it will be deleted). + node.m_send_memusage -= memusage; + ++it; + } + } + const auto& [data, more, msg_type] = node.m_transport->GetBytesToSend(); + data_left = !data.empty(); // will be overwritten on next loop if all of data gets sent int nBytes = 0; - { + if (!data.empty()) { LOCK(node.m_sock_mutex); + // There is no socket in case we've already disconnected, or in test cases without + // real connections. In these cases, we bail out immediately and just leave things + // in the send queue and transport. if (!node.m_sock) { break; } int flags = MSG_NOSIGNAL | MSG_DONTWAIT; #ifdef MSG_MORE - if (it + 1 != node.vSendMsg.end()) { + // We have more to send if either the transport itself has more, or if we have more + // messages to send. + if (more || it != node.vSendMsg.end()) { flags |= MSG_MORE; } #endif - nBytes = node.m_sock->Send(reinterpret_cast(data.data()) + node.nSendOffset, data.size() - node.nSendOffset, flags); + nBytes = node.m_sock->Send(reinterpret_cast(data.data()), data.size(), flags); } if (nBytes > 0) { node.m_last_send = GetTime(); node.nSendBytes += nBytes; - node.nSendOffset += nBytes; + // Notify transport that bytes have been processed. + node.m_transport->MarkBytesSent(nBytes); + // Update statistics per message type. + node.mapSendBytesPerMsgType[msg_type] += nBytes; nSentSize += nBytes; - if (node.nSendOffset == data.size()) { - node.nSendOffset = 0; - node.nSendSize -= data.size(); - node.fPauseSend = node.nSendSize > nSendBufferMaxSize; - it++; - } else { + if ((size_t)nBytes != data.size()) { // could not send full message; stop sending more node.fCanSendData = false; break; @@ -986,19 +1077,18 @@ std::pair CConnman::SocketSendData(CNode& node) const node.fDisconnect = true; } } - // couldn't send anything at all - node.fCanSendData = false; break; } } + node.fPauseSend = node.m_send_memusage + node.m_transport->GetSendMemoryUsage() > nSendBufferMaxSize; + if (it == node.vSendMsg.end()) { - assert(node.nSendOffset == 0); - assert(node.nSendSize == 0); + assert(node.m_send_memusage == 0); } node.vSendMsg.erase(node.vSendMsg.begin(), it); node.nSendMsgSize = node.vSendMsg.size(); - return {nSentSize, !node.vSendMsg.empty()}; + return {nSentSize, data_left}; } static bool ReverseCompareNodeMinPingTime(const NodeEvictionCandidate& a, const NodeEvictionCandidate& b) @@ -1523,7 +1613,9 @@ void CConnman::DisconnectNodes() } if (GetTimeMillis() < pnode->nDisconnectLingerTime) { // everything flushed to the kernel? - if (!pnode->fSocketShutdown && pnode->nSendMsgSize == 0) { + const auto& [to_send, _more, _msg_type] = pnode->m_transport->GetBytesToSend(); + const bool queue_is_empty{to_send.empty() && pnode->nSendMsgSize == 0}; + if (!pnode->fSocketShutdown && queue_is_empty) { LOCK(pnode->m_sock_mutex); if (pnode->m_sock) { // Give the other side a chance to detect the disconnect as early as possible (recv() will return 0) @@ -2088,7 +2180,9 @@ void CConnman::SocketHandlerConnected(const std::set& recv_set, // receiving data. This means properly utilizing TCP flow control signalling. // * Otherwise, if there is space left in the receive buffer (!fPauseRecv), try // receiving data (which should succeed as the socket signalled as receivable). - if (!it->second->fPauseRecv && it->second->nSendMsgSize == 0 && !it->second->fDisconnect) { + const auto& [to_send, _more, _msg_type] = it->second->m_transport->GetBytesToSend(); + const bool queue_is_empty{to_send.empty() && it->second->nSendMsgSize == 0}; + if (!it->second->fPauseRecv && !it->second->fDisconnect && queue_is_empty) { it->second->AddRef(); vReceivableNodes.emplace(it->second); } @@ -2102,7 +2196,8 @@ void CConnman::SocketHandlerConnected(const std::set& recv_set, // but don't have any in this iteration LOCK(cs_mapNodesWithDataToSend); for (auto it = mapNodesWithDataToSend.begin(); it != mapNodesWithDataToSend.end(); ) { - if (it->second->nSendMsgSize == 0) { + const auto& [to_send, _more, _msg_type] = it->second->m_transport->GetBytesToSend(); + if (to_send.empty() && it->second->nSendMsgSize == 0) { // See comment in PushMessage it->second->Release(); it = mapNodesWithDataToSend.erase(it); @@ -4146,8 +4241,7 @@ CNode::CNode(NodeId idIn, ConnectionType conn_type_in, bool inbound_onion, std::unique_ptr&& i2p_sam_session) - : m_deserializer{std::make_unique(V1TransportDeserializer(Params(), idIn, SER_NETWORK, INIT_PROTO_VERSION))}, - m_serializer{std::make_unique(V1TransportSerializer())}, + : m_transport{std::make_unique(idIn, SER_NETWORK, INIT_PROTO_VERSION)}, m_sock{sock}, m_connected{GetTime()}, addr{addrIn}, @@ -4186,26 +4280,19 @@ void CConnman::PushMessage(CNode* pnode, CSerializedNetMsg&& msg) if (gArgs.GetBoolArg("-capturemessages", false)) { CaptureMessage(pnode->addr, msg.m_type, msg.data, /* incoming */ false); } - - // make sure we use the appropriate network transport format - std::vector serializedHeader; - pnode->m_serializer->prepareForTransport(msg, serializedHeader); - - size_t nTotalSize = nMessageSize + serializedHeader.size(); - statsClient.count("bandwidth.message." + SanitizeString(msg.m_type.c_str()) + ".bytesSent", nTotalSize, 1.0f); - statsClient.inc("message.sent." + SanitizeString(msg.m_type.c_str()), 1.0f); + statsClient.count(strprintf("bandwidth.message.%s.bytesSent", msg.m_type), nMessageSize, 1.0f); + statsClient.inc(strprintf("message.sent.%s", msg.m_type), 1.0f); { LOCK(pnode->cs_vSend); - bool optimisticSend(pnode->vSendMsg.empty()); - - //log total amount of bytes per message type - pnode->mapSendBytesPerMsgType[msg.m_type] += nTotalSize; - pnode->nSendSize += nTotalSize; - - if (pnode->nSendSize > nSendBufferMaxSize) pnode->fPauseSend = true; - pnode->vSendMsg.push_back(std::move(serializedHeader)); - if (nMessageSize) pnode->vSendMsg.push_back(std::move(msg.data)); + const auto& [to_send, _more, _msg_type] = pnode->m_transport->GetBytesToSend(); + const bool queue_was_empty{to_send.empty() && pnode->vSendMsg.empty()}; + + // Update memory usage of send buffer. + pnode->m_send_memusage += msg.GetMemoryUsage(); + if (pnode->m_send_memusage + pnode->m_transport->GetSendMemoryUsage() > nSendBufferMaxSize) pnode->fPauseSend = true; + // Move message to vSendMsg queue. + pnode->vSendMsg.push_back(std::move(msg)); pnode->nSendMsgSize = pnode->vSendMsg.size(); { @@ -4219,9 +4306,13 @@ void CConnman::PushMessage(CNode* pnode, CSerializedNetMsg&& msg) } } - // wake up select() call in case there was no pending data before (so it was not selecting this socket for sending) - if (optimisticSend && (m_wakeup_pipe && m_wakeup_pipe->m_need_wakeup.load())) - m_wakeup_pipe->Write(); + // Wake up select() call in case there was no pending data before (so it was not selecting + // this socket for sending) + if (queue_was_empty) { + if (m_wakeup_pipe && m_wakeup_pipe->m_need_wakeup.load()) { + m_wakeup_pipe->Write(); + } + } } } diff --git a/src/net.h b/src/net.h index 70dd0d5fab7210..a4d6a7dfcd202d 100644 --- a/src/net.h +++ b/src/net.h @@ -151,6 +151,9 @@ struct CSerializedNetMsg { std::vector data; std::string m_type; + + /** Compute total memory usage of this object (own memory + any dynamic memory). */ + size_t GetMemoryUsage() const noexcept; }; /** Different types of connections to a peer. This enum encapsulates the @@ -350,42 +353,105 @@ class CNetMessage { } }; -/** The TransportDeserializer takes care of holding and deserializing the - * network receive buffer. It can deserialize the network buffer into a - * transport protocol agnostic CNetMessage (message type & payload) - */ -class TransportDeserializer { +/** The Transport converts one connection's sent messages to wire bytes, and received bytes back. */ +class Transport { public: - // returns true if the current deserialization is complete - virtual bool Complete() const = 0; - // set the serialization context version - virtual void SetVersion(int version) = 0; - /** read and deserialize data, advances msg_bytes data pointer */ - virtual int Read(Span& msg_bytes) = 0; - // decomposes a message from the context - virtual CNetMessage GetMessage(std::chrono::microseconds time, bool& reject_message) = 0; - virtual ~TransportDeserializer() {} + virtual ~Transport() {} + + // 1. Receiver side functions, for decoding bytes received on the wire into transport protocol + // agnostic CNetMessage (message type & payload) objects. + + /** Returns true if the current message is complete (so GetReceivedMessage can be called). */ + virtual bool ReceivedMessageComplete() const = 0; + /** Set the deserialization context version for objects returned by GetReceivedMessage. */ + virtual void SetReceiveVersion(int version) = 0; + + /** Feed wire bytes to the transport. + * + * @return false if some bytes were invalid, in which case the transport can't be used anymore. + * + * Consumed bytes are chopped off the front of msg_bytes. + */ + virtual bool ReceivedBytes(Span& msg_bytes) = 0; + + /** Retrieve a completed message from transport. + * + * This can only be called when ReceivedMessageComplete() is true. + * + * If reject_message=true is returned the message itself is invalid, but (other than false + * returned by ReceivedBytes) the transport is not in an inconsistent state. + */ + virtual CNetMessage GetReceivedMessage(std::chrono::microseconds time, bool& reject_message) = 0; + + // 2. Sending side functions, for converting messages into bytes to be sent over the wire. + + /** Set the next message to send. + * + * If no message can currently be set (perhaps because the previous one is not yet done being + * sent), returns false, and msg will be unmodified. Otherwise msg is enqueued (and + * possibly moved-from) and true is returned. + */ + virtual bool SetMessageToSend(CSerializedNetMsg& msg) noexcept = 0; + + /** Return type for GetBytesToSend, consisting of: + * - Span to_send: span of bytes to be sent over the wire (possibly empty). + * - bool more: whether there will be more bytes to be sent after the ones in to_send are + * all sent (as signaled by MarkBytesSent()). + * - const std::string& m_type: message type on behalf of which this is being sent. + */ + using BytesToSend = std::tuple< + Span /*to_send*/, + bool /*more*/, + const std::string& /*m_type*/ + >; + + /** Get bytes to send on the wire. + * + * As a const function, it does not modify the transport's observable state, and is thus safe + * to be called multiple times. + * + * The bytes returned by this function act as a stream which can only be appended to. This + * means that with the exception of MarkBytesSent, operations on the transport can only append + * to what is being returned. + * + * Note that m_type and to_send refer to data that is internal to the transport, and calling + * any non-const function on this object may invalidate them. + */ + virtual BytesToSend GetBytesToSend() const noexcept = 0; + + /** Report how many bytes returned by the last GetBytesToSend() have been sent. + * + * bytes_sent cannot exceed to_send.size() of the last GetBytesToSend() result. + * + * If bytes_sent=0, this call has no effect. + */ + virtual void MarkBytesSent(size_t bytes_sent) noexcept = 0; + + /** Return the memory usage of this transport attributable to buffered data to send. */ + virtual size_t GetSendMemoryUsage() const noexcept = 0; }; -class V1TransportDeserializer final : public TransportDeserializer +class V1Transport final : public Transport { private: - const CChainParams& m_chain_params; + CMessageHeader::MessageStartChars m_magic_bytes; const NodeId m_node_id; // Only for logging - mutable CHash256 hasher; - mutable uint256 data_hash; - bool in_data; // parsing header (false) or data (true) - CDataStream hdrbuf; // partially received header - CMessageHeader hdr; // complete header - CDataStream vRecv; // received message data - unsigned int nHdrPos; - unsigned int nDataPos; - - const uint256& GetMessageHash() const; - int readHeader(Span msg_bytes); - int readData(Span msg_bytes); - - void Reset() { + mutable Mutex m_recv_mutex; //!< Lock for receive state + mutable CHash256 hasher GUARDED_BY(m_recv_mutex); + mutable uint256 data_hash GUARDED_BY(m_recv_mutex); + bool in_data GUARDED_BY(m_recv_mutex); // parsing header (false) or data (true) + CDataStream hdrbuf GUARDED_BY(m_recv_mutex); // partially received header + CMessageHeader hdr GUARDED_BY(m_recv_mutex); // complete header + CDataStream vRecv GUARDED_BY(m_recv_mutex); // received message data + unsigned int nHdrPos GUARDED_BY(m_recv_mutex); + unsigned int nDataPos GUARDED_BY(m_recv_mutex); + + const uint256& GetMessageHash() const EXCLUSIVE_LOCKS_REQUIRED(m_recv_mutex); + int readHeader(Span msg_bytes) EXCLUSIVE_LOCKS_REQUIRED(m_recv_mutex); + int readData(Span msg_bytes) EXCLUSIVE_LOCKS_REQUIRED(m_recv_mutex); + + void Reset() EXCLUSIVE_LOCKS_REQUIRED(m_recv_mutex) { + AssertLockHeld(m_recv_mutex); vRecv.clear(); hdrbuf.clear(); hdrbuf.resize(24); @@ -396,52 +462,60 @@ class V1TransportDeserializer final : public TransportDeserializer hasher.Reset(); } -public: - V1TransportDeserializer(const CChainParams& chain_params, const NodeId node_id, int nTypeIn, int nVersionIn) - : m_chain_params(chain_params), - m_node_id(node_id), - hdrbuf(nTypeIn, nVersionIn), - vRecv(nTypeIn, nVersionIn) + bool CompleteInternal() const noexcept EXCLUSIVE_LOCKS_REQUIRED(m_recv_mutex) { - Reset(); + AssertLockHeld(m_recv_mutex); + if (!in_data) return false; + return hdr.nMessageSize == nDataPos; } - bool Complete() const override + /** Lock for sending state. */ + mutable Mutex m_send_mutex; + /** The header of the message currently being sent. */ + std::vector m_header_to_send GUARDED_BY(m_send_mutex); + /** The data of the message currently being sent. */ + CSerializedNetMsg m_message_to_send GUARDED_BY(m_send_mutex); + /** Whether we're currently sending header bytes or message bytes. */ + bool m_sending_header GUARDED_BY(m_send_mutex) {false}; + /** How many bytes have been sent so far (from m_header_to_send, or from m_message_to_send.data). */ + size_t m_bytes_sent GUARDED_BY(m_send_mutex) {0}; + +public: + V1Transport(const NodeId node_id, int nTypeIn, int nVersionIn) noexcept; + + bool ReceivedMessageComplete() const override EXCLUSIVE_LOCKS_REQUIRED(!m_recv_mutex) { - if (!in_data) - return false; - return (hdr.nMessageSize == nDataPos); + AssertLockNotHeld(m_recv_mutex); + return WITH_LOCK(m_recv_mutex, return CompleteInternal()); } - void SetVersion(int nVersionIn) override + + void SetReceiveVersion(int nVersionIn) override EXCLUSIVE_LOCKS_REQUIRED(!m_recv_mutex) { + AssertLockNotHeld(m_recv_mutex); + LOCK(m_recv_mutex); hdrbuf.SetVersion(nVersionIn); vRecv.SetVersion(nVersionIn); } - int Read(Span& msg_bytes) override + + bool ReceivedBytes(Span& msg_bytes) override EXCLUSIVE_LOCKS_REQUIRED(!m_recv_mutex) { + AssertLockNotHeld(m_recv_mutex); + LOCK(m_recv_mutex); int ret = in_data ? readData(msg_bytes) : readHeader(msg_bytes); if (ret < 0) { Reset(); } else { msg_bytes = msg_bytes.subspan(ret); } - return ret; + return ret >= 0; } - CNetMessage GetMessage(std::chrono::microseconds time, bool& reject_message) override; -}; -/** The TransportSerializer prepares messages for the network transport - */ -class TransportSerializer { -public: - // prepare message for transport (header construction, error-correction computation, payload encryption, etc.) - virtual void prepareForTransport(CSerializedNetMsg& msg, std::vector& header) const = 0; - virtual ~TransportSerializer() {} -}; + CNetMessage GetReceivedMessage(std::chrono::microseconds time, bool& reject_message) override EXCLUSIVE_LOCKS_REQUIRED(!m_recv_mutex); -class V1TransportSerializer : public TransportSerializer { -public: - void prepareForTransport(CSerializedNetMsg& msg, std::vector& header) const override; + bool SetMessageToSend(CSerializedNetMsg& msg) noexcept override EXCLUSIVE_LOCKS_REQUIRED(!m_send_mutex); + BytesToSend GetBytesToSend() const noexcept override EXCLUSIVE_LOCKS_REQUIRED(!m_send_mutex); + void MarkBytesSent(size_t bytes_sent) noexcept override EXCLUSIVE_LOCKS_REQUIRED(!m_send_mutex); + size_t GetSendMemoryUsage() const noexcept override EXCLUSIVE_LOCKS_REQUIRED(!m_send_mutex); }; /** Information about a peer */ @@ -451,8 +525,9 @@ class CNode friend struct ConnmanTestMsg; public: - const std::unique_ptr m_deserializer; // Used only by SocketHandler thread - const std::unique_ptr m_serializer; + /** Transport serializer/deserializer. The receive side functions are only called under cs_vRecv, while + * the sending side functions are only called under cs_vSend. */ + const std::unique_ptr m_transport; NetPermissionFlags m_permissionFlags{NetPermissionFlags::None}; // treated as const outside of fuzz tester @@ -466,12 +541,12 @@ class CNode */ std::shared_ptr m_sock GUARDED_BY(m_sock_mutex); - /** Total size of all vSendMsg entries */ - size_t nSendSize GUARDED_BY(cs_vSend){0}; - /** Offset inside the first vSendMsg already sent */ - size_t nSendOffset GUARDED_BY(cs_vSend){0}; + /** Sum of GetMemoryUsage of all vSendMsg entries. */ + size_t m_send_memusage GUARDED_BY(cs_vSend){0}; + /** Total number of bytes sent on the wire to this peer. */ uint64_t nSendBytes GUARDED_BY(cs_vSend){0}; - std::deque> vSendMsg GUARDED_BY(cs_vSend); + /** Messages still to be fed to m_transport->SetMessageToSend. */ + std::deque vSendMsg GUARDED_BY(cs_vSend); std::atomic nSendMsgSize{0}; Mutex cs_vSend; Mutex m_sock_mutex; diff --git a/src/test/denialofservice_tests.cpp b/src/test/denialofservice_tests.cpp index bc9daa674468f9..8c5cee5371c2d8 100644 --- a/src/test/denialofservice_tests.cpp +++ b/src/test/denialofservice_tests.cpp @@ -89,8 +89,11 @@ BOOST_AUTO_TEST_CASE(outbound_slow_chain_eviction) { LOCK(dummyNode1.cs_vSend); BOOST_CHECK(dummyNode1.vSendMsg.size() > 0); - dummyNode1.vSendMsg.clear(); - dummyNode1.nSendMsgSize = 0; + } + connman.FlushSendBuffer(dummyNode1); + { + LOCK(dummyNode1.cs_vSend); + BOOST_CHECK(dummyNode1.vSendMsg.empty()); } int64_t nStartTime = GetTime(); diff --git a/src/test/fuzz/p2p_transport_serialization.cpp b/src/test/fuzz/p2p_transport_serialization.cpp index 8247bbabc43d15..1931562f5a27c4 100644 --- a/src/test/fuzz/p2p_transport_serialization.cpp +++ b/src/test/fuzz/p2p_transport_serialization.cpp @@ -9,6 +9,8 @@ #include #include #include +#include +#include #include #include @@ -16,16 +18,21 @@ #include #include +std::vector g_all_messages; + void initialize_p2p_transport_serialization() { SelectParams(CBaseChainParams::REGTEST); + g_all_messages = getAllNetMessageTypes(); + std::sort(g_all_messages.begin(), g_all_messages.end()); } FUZZ_TARGET_INIT(p2p_transport_serialization, initialize_p2p_transport_serialization) { - // Construct deserializer, with a dummy NodeId - V1TransportDeserializer deserializer{Params(), (NodeId)0, SER_NETWORK, INIT_PROTO_VERSION}; - V1TransportSerializer serializer{}; + // Construct transports for both sides, with dummy NodeIds. + V1Transport recv_transport{NodeId{0}, SER_NETWORK, INIT_PROTO_VERSION}; + V1Transport send_transport{NodeId{1}, SER_NETWORK, INIT_PROTO_VERSION}; + FuzzedDataProvider fuzzed_data_provider{buffer.data(), buffer.size()}; auto checksum_assist = fuzzed_data_provider.ConsumeBool(); @@ -62,14 +69,13 @@ FUZZ_TARGET_INIT(p2p_transport_serialization, initialize_p2p_transport_serializa mutable_msg_bytes.insert(mutable_msg_bytes.end(), payload_bytes.begin(), payload_bytes.end()); Span msg_bytes{mutable_msg_bytes}; while (msg_bytes.size() > 0) { - const int handled = deserializer.Read(msg_bytes); - if (handled < 0) { + if (!recv_transport.ReceivedBytes(msg_bytes)) { break; } - if (deserializer.Complete()) { + if (recv_transport.ReceivedMessageComplete()) { const std::chrono::microseconds m_time{std::numeric_limits::max()}; bool reject_message{false}; - CNetMessage msg = deserializer.GetMessage(m_time, reject_message); + CNetMessage msg = recv_transport.GetReceivedMessage(m_time, reject_message); assert(msg.m_type.size() <= CMessageHeader::COMMAND_SIZE); assert(msg.m_raw_message_size <= mutable_msg_bytes.size()); assert(msg.m_raw_message_size == CMessageHeader::HEADER_SIZE + msg.m_message_size); @@ -77,7 +83,247 @@ FUZZ_TARGET_INIT(p2p_transport_serialization, initialize_p2p_transport_serializa std::vector header; auto msg2 = CNetMsgMaker{msg.m_recv.GetVersion()}.Make(msg.m_type, MakeUCharSpan(msg.m_recv)); - serializer.prepareForTransport(msg2, header); + bool queued = send_transport.SetMessageToSend(msg2); + assert(queued); + std::optional known_more; + while (true) { + const auto& [to_send, more, _msg_type] = send_transport.GetBytesToSend(); + if (known_more) assert(!to_send.empty() == *known_more); + if (to_send.empty()) break; + send_transport.MarkBytesSent(to_send.size()); + known_more = more; + } } } } + +namespace { + +template +void SimulationTest(Transport& initiator, Transport& responder, R& rng, FuzzedDataProvider& provider) +{ + // Simulation test with two Transport objects, which send messages to each other, with + // sending and receiving fragmented into multiple pieces that may be interleaved. It primarily + // verifies that the sending and receiving side are compatible with each other, plus a few + // sanity checks. It does not attempt to introduce errors in the communicated data. + + // Put the transports in an array for by-index access. + const std::array transports = {&initiator, &responder}; + + // Two vectors representing in-flight bytes. inflight[i] is from transport[i] to transport[!i]. + std::array, 2> in_flight; + + // Two queues with expected messages. expected[i] is expected to arrive in transport[!i]. + std::array, 2> expected; + + // Vectors with bytes last returned by GetBytesToSend() on transport[i]. + std::array, 2> to_send; + + // Last returned 'more' values (if still relevant) by transport[i]->GetBytesToSend(). + std::array, 2> last_more; + + // Whether more bytes to be sent are expected on transport[i]. + std::array, 2> expect_more; + + // Function to consume a message type. + auto msg_type_fn = [&]() { + uint8_t v = provider.ConsumeIntegral(); + if (v == 0xFF) { + // If v is 0xFF, construct a valid (but possibly unknown) message type from the fuzz + // data. + std::string ret; + while (ret.size() < CMessageHeader::COMMAND_SIZE) { + char c = provider.ConsumeIntegral(); + // Match the allowed characters in CMessageHeader::IsCommandValid(). Any other + // character is interpreted as end. + if (c < ' ' || c > 0x7E) break; + ret += c; + } + return ret; + } else { + // Otherwise, use it as index into the list of known messages. + return g_all_messages[v % g_all_messages.size()]; + } + }; + + // Function to construct a CSerializedNetMsg to send. + auto make_msg_fn = [&](bool first) { + CSerializedNetMsg msg; + if (first) { + // Always send a "version" message as first one. + msg.m_type = "version"; + } else { + msg.m_type = msg_type_fn(); + } + // Determine size of message to send (limited to 75 kB for performance reasons). + size_t size = provider.ConsumeIntegralInRange(0, 75000); + // Get payload of message from RNG. + msg.data.resize(size); + for (auto& v : msg.data) v = uint8_t(rng()); + // Return. + return msg; + }; + + // The next message to be sent (initially version messages, but will be replaced once sent). + std::array next_msg = { + make_msg_fn(/*first=*/true), + make_msg_fn(/*first=*/true) + }; + + // Wrapper around transport[i]->GetBytesToSend() that performs sanity checks. + auto bytes_to_send_fn = [&](int side) -> Transport::BytesToSend { + const auto& [bytes, more, msg_type] = transports[side]->GetBytesToSend(); + // Compare with expected more. + if (expect_more[side].has_value()) assert(!bytes.empty() == *expect_more[side]); + // Compare with previously reported output. + assert(to_send[side].size() <= bytes.size()); + assert(to_send[side] == Span{bytes}.first(to_send[side].size())); + to_send[side].resize(bytes.size()); + std::copy(bytes.begin(), bytes.end(), to_send[side].begin()); + // Remember 'more' result. + last_more[side] = {more}; + // Return. + return {bytes, more, msg_type}; + }; + + // Function to make side send a new message. + auto new_msg_fn = [&](int side) { + // Don't do anything if there are too many unreceived messages already. + if (expected[side].size() >= 16) return; + // Try to send (a copy of) the message in next_msg[side]. + CSerializedNetMsg msg = next_msg[side].Copy(); + bool queued = transports[side]->SetMessageToSend(msg); + // Update expected more data. + expect_more[side] = std::nullopt; + // Verify consistency of GetBytesToSend after SetMessageToSend + bytes_to_send_fn(/*side=*/side); + if (queued) { + // Remember that this message is now expected by the receiver. + expected[side].emplace_back(std::move(next_msg[side])); + // Construct a new next message to send. + next_msg[side] = make_msg_fn(/*first=*/false); + } + }; + + // Function to make side send out bytes (if any). + auto send_fn = [&](int side, bool everything = false) { + const auto& [bytes, more, msg_type] = bytes_to_send_fn(/*side=*/side); + // Don't do anything if no bytes to send. + if (bytes.empty()) return false; + size_t send_now = everything ? bytes.size() : provider.ConsumeIntegralInRange(0, bytes.size()); + if (send_now == 0) return false; + // Add bytes to the in-flight queue, and mark those bytes as consumed. + in_flight[side].insert(in_flight[side].end(), bytes.begin(), bytes.begin() + send_now); + transports[side]->MarkBytesSent(send_now); + // If all to-be-sent bytes were sent, move last_more data to expect_more data. + if (send_now == bytes.size()) { + expect_more[side] = last_more[side]; + } + // Remove the bytes from the last reported to-be-sent vector. + assert(to_send[side].size() >= send_now); + to_send[side].erase(to_send[side].begin(), to_send[side].begin() + send_now); + // Verify that GetBytesToSend gives a result consistent with earlier. + bytes_to_send_fn(/*side=*/side); + // Return whether anything was sent. + return send_now > 0; + }; + + // Function to make !side receive bytes (if any). + auto recv_fn = [&](int side, bool everything = false) { + // Don't do anything if no bytes in flight. + if (in_flight[side].empty()) return false; + // Decide span to receive + size_t to_recv_len = in_flight[side].size(); + if (!everything) to_recv_len = provider.ConsumeIntegralInRange(0, to_recv_len); + Span to_recv = Span{in_flight[side]}.first(to_recv_len); + // Process those bytes + while (!to_recv.empty()) { + size_t old_len = to_recv.size(); + bool ret = transports[!side]->ReceivedBytes(to_recv); + // Bytes must always be accepted, as this test does not introduce any errors in + // communication. + assert(ret); + // Clear cached expected 'more' information: if certainly no more data was to be sent + // before, receiving bytes makes this uncertain. + if (expect_more[!side] == false) expect_more[!side] = std::nullopt; + // Verify consistency of GetBytesToSend after ReceivedBytes + bytes_to_send_fn(/*side=*/!side); + bool progress = to_recv.size() < old_len; + if (transports[!side]->ReceivedMessageComplete()) { + bool reject{false}; + auto received = transports[!side]->GetReceivedMessage({}, reject); + // Receiving must succeed. + assert(!reject); + // There must be a corresponding expected message. + assert(!expected[side].empty()); + // The m_message_size field must be correct. + assert(received.m_message_size == received.m_recv.size()); + // The m_type must match what is expected. + assert(received.m_type == expected[side].front().m_type); + // The data must match what is expected. + assert(MakeByteSpan(received.m_recv) == MakeByteSpan(expected[side].front().data)); + expected[side].pop_front(); + progress = true; + } + // Progress must be made (by processing incoming bytes and/or returning complete + // messages) until all received bytes are processed. + assert(progress); + } + // Remove the processed bytes from the in_flight buffer. + in_flight[side].erase(in_flight[side].begin(), in_flight[side].begin() + to_recv_len); + // Return whether anything was received. + return to_recv_len > 0; + }; + + // Main loop, interleaving new messages, sends, and receives. + LIMITED_WHILE(provider.remaining_bytes(), 1000) { + CallOneOf(provider, + // (Try to) give the next message to the transport. + [&] { new_msg_fn(/*side=*/0); }, + [&] { new_msg_fn(/*side=*/1); }, + // (Try to) send some bytes from the transport to the network. + [&] { send_fn(/*side=*/0); }, + [&] { send_fn(/*side=*/1); }, + // (Try to) receive bytes from the network, converting to messages. + [&] { recv_fn(/*side=*/0); }, + [&] { recv_fn(/*side=*/1); } + ); + } + + // When we're done, perform sends and receives of existing messages to flush anything already + // in flight. + while (true) { + bool any = false; + if (send_fn(/*side=*/0, /*everything=*/true)) any = true; + if (send_fn(/*side=*/1, /*everything=*/true)) any = true; + if (recv_fn(/*side=*/0, /*everything=*/true)) any = true; + if (recv_fn(/*side=*/1, /*everything=*/true)) any = true; + if (!any) break; + } + + // Make sure nothing is left in flight. + assert(in_flight[0].empty()); + assert(in_flight[1].empty()); + + // Make sure all expected messages were received. + assert(expected[0].empty()); + assert(expected[1].empty()); +} + +std::unique_ptr MakeV1Transport(NodeId nodeid) noexcept +{ + return std::make_unique(nodeid, SER_NETWORK, INIT_PROTO_VERSION); +} + +} // namespace + +FUZZ_TARGET_INIT(p2p_transport_bidirectional, initialize_p2p_transport_serialization) +{ + // Test with two V1 transports talking to each other. + FuzzedDataProvider provider{buffer.data(), buffer.size()}; + XoRoShiRo128PlusPlus rng(provider.ConsumeIntegral()); + auto t1 = MakeV1Transport(NodeId{0}); + auto t2 = MakeV1Transport(NodeId{1}); + if (!t1 || !t2) return; + SimulationTest(*t1, *t2, rng, provider); +} diff --git a/src/test/fuzz/process_messages.cpp b/src/test/fuzz/process_messages.cpp index 6e270f8d60f7c6..938f189db1e17f 100644 --- a/src/test/fuzz/process_messages.cpp +++ b/src/test/fuzz/process_messages.cpp @@ -65,7 +65,8 @@ FUZZ_TARGET_INIT(process_messages, initialize_process_messages) CNode& random_node = *PickValue(fuzzed_data_provider, peers); - (void)connman.ReceiveMsgFrom(random_node, net_msg); + connman.FlushSendBuffer(random_node); + (void)connman.ReceiveMsgFrom(random_node, std::move(net_msg)); random_node.fPauseSend = false; try { diff --git a/src/test/util/net.cpp b/src/test/util/net.cpp index b636559c4ae0c9..fc6a29d4243b4a 100644 --- a/src/test/util/net.cpp +++ b/src/test/util/net.cpp @@ -25,6 +25,7 @@ void ConnmanTestMsg::Handshake(CNode& node, const CNetMsgMaker mm{0}; peerman.InitializeNode(node, local_services); + FlushSendBuffer(node); // Drop the version message added by InitializeNode. CSerializedNetMsg msg_version{ mm.Make(NetMsgType::VERSION, @@ -41,10 +42,11 @@ void ConnmanTestMsg::Handshake(CNode& node, relay_txs), }; - (void)connman.ReceiveMsgFrom(node, msg_version); + (void)connman.ReceiveMsgFrom(node, std::move(msg_version)); node.fPauseSend = false; connman.ProcessMessagesOnce(node); peerman.SendMessages(&node); + FlushSendBuffer(node); // Drop the verack message added by SendMessages. if (node.fDisconnect) return; assert(node.nVersion == version); assert(node.GetCommonVersion() == std::min(version, PROTOCOL_VERSION)); @@ -55,7 +57,7 @@ void ConnmanTestMsg::Handshake(CNode& node, node.m_permissionFlags = permission_flags; if (successfully_connected) { CSerializedNetMsg msg_verack{mm.Make(NetMsgType::VERACK)}; - (void)connman.ReceiveMsgFrom(node, msg_verack); + (void)connman.ReceiveMsgFrom(node, std::move(msg_verack)); node.fPauseSend = false; connman.ProcessMessagesOnce(node); peerman.SendMessages(&node); @@ -83,14 +85,29 @@ void ConnmanTestMsg::NodeReceiveMsgBytes(CNode& node, Span msg_by } } -bool ConnmanTestMsg::ReceiveMsgFrom(CNode& node, CSerializedNetMsg& ser_msg) const +void ConnmanTestMsg::FlushSendBuffer(CNode& node) const { - std::vector ser_msg_header; - node.m_serializer->prepareForTransport(ser_msg, ser_msg_header); + LOCK(node.cs_vSend); + node.vSendMsg.clear(); + node.m_send_memusage = 0; + while (true) { + const auto& [to_send, _more, _msg_type] = node.m_transport->GetBytesToSend(); + if (to_send.empty()) break; + node.m_transport->MarkBytesSent(to_send.size()); + } +} - bool complete; - NodeReceiveMsgBytes(node, ser_msg_header, complete); - NodeReceiveMsgBytes(node, ser_msg.data, complete); +bool ConnmanTestMsg::ReceiveMsgFrom(CNode& node, CSerializedNetMsg&& ser_msg) const +{ + bool queued = node.m_transport->SetMessageToSend(ser_msg); + assert(queued); + bool complete{false}; + while (true) { + const auto& [to_send, _more, _msg_type] = node.m_transport->GetBytesToSend(); + if (to_send.empty()) break; + NodeReceiveMsgBytes(node, to_send, complete); + node.m_transport->MarkBytesSent(to_send.size()); + } return complete; } diff --git a/src/test/util/net.h b/src/test/util/net.h index 1f2a31290e3b99..dc758bd43ed195 100644 --- a/src/test/util/net.h +++ b/src/test/util/net.h @@ -51,7 +51,8 @@ struct ConnmanTestMsg : public CConnman { void NodeReceiveMsgBytes(CNode& node, Span msg_bytes, bool& complete) const; - bool ReceiveMsgFrom(CNode& node, CSerializedNetMsg& ser_msg) const; + bool ReceiveMsgFrom(CNode& node, CSerializedNetMsg&& ser_msg) const; + void FlushSendBuffer(CNode& node) const; }; constexpr ServiceFlags ALL_SERVICE_FLAGS[]{