Skip to content

Commit

Permalink
refactor: tr_peer_socket keeps track of peer count (transmission#4534)
Browse files Browse the repository at this point in the history
  • Loading branch information
ckerr authored Jan 4, 2023
1 parent c95891e commit b47c347
Show file tree
Hide file tree
Showing 10 changed files with 56 additions and 68 deletions.
44 changes: 16 additions & 28 deletions libtransmission/net.cc
Original file line number Diff line number Diff line change
Expand Up @@ -142,10 +142,8 @@ void tr_netSetCongestionControl([[maybe_unused]] tr_socket_t s, [[maybe_unused]]
#endif
}

static tr_socket_t createSocket(tr_session* session, int domain, int type)
static tr_socket_t createSocket(int domain, int type)
{
TR_ASSERT(session != nullptr);

auto const sockfd = socket(domain, type, 0);
if (sockfd == TR_BAD_SOCKET)
{
Expand All @@ -160,9 +158,9 @@ static tr_socket_t createSocket(tr_session* session, int domain, int type)
return TR_BAD_SOCKET;
}

if ((evutil_make_socket_nonblocking(sockfd) == -1) || !session->incPeerCount())
if (evutil_make_socket_nonblocking(sockfd) == -1)
{
tr_netClose(session, sockfd);
tr_net_close_socket(sockfd);
return TR_BAD_SOCKET;
}

Expand Down Expand Up @@ -193,19 +191,15 @@ static tr_socket_t createSocket(tr_session* session, int domain, int type)
tr_peer_socket tr_netOpenPeerSocket(tr_session* session, tr_address const& addr, tr_port port, bool client_is_seed)
{
TR_ASSERT(addr.is_valid());
TR_ASSERT(!tr_peer_socket::limit_reached(session));

if (!session->allowsTCP())
{
return {};
}

if (!addr.is_valid_for_peers(port))
if (tr_peer_socket::limit_reached(session) || !session->allowsTCP() || !addr.is_valid_for_peers(port))
{
return {};
}

static auto constexpr Domains = std::array<int, NUM_TR_AF_INET_TYPES>{ AF_INET, AF_INET6 };
auto const s = createSocket(session, Domains[addr.type], SOCK_STREAM);
auto const s = createSocket(Domains[addr.type], SOCK_STREAM);
if (s == TR_BAD_SOCKET)
{
return {};
Expand Down Expand Up @@ -236,7 +230,7 @@ tr_peer_socket tr_netOpenPeerSocket(tr_session* session, tr_address const& addr,
fmt::arg("socket", s),
fmt::arg("error", tr_net_strerror(sockerrno)),
fmt::arg("error_code", sockerrno)));
tr_netClose(session, s);
tr_net_close_socket(s);
return {};
}

Expand All @@ -258,7 +252,7 @@ tr_peer_socket tr_netOpenPeerSocket(tr_session* session, tr_address const& addr,
fmt::arg("error_code", tmperrno)));
}

tr_netClose(session, s);
tr_net_close_socket(s);
}
else
{
Expand Down Expand Up @@ -286,7 +280,7 @@ static tr_socket_t tr_netBindTCPImpl(tr_address const& addr, tr_port port, bool
if (evutil_make_socket_nonblocking(fd) == -1)
{
*err_out = sockerrno;
tr_netCloseSocket(fd);
tr_net_close_socket(fd);
return TR_BAD_SOCKET;
}

Expand All @@ -301,7 +295,7 @@ static tr_socket_t tr_netBindTCPImpl(tr_address const& addr, tr_port port, bool
(sockerrno != ENOPROTOOPT)) // if the kernel doesn't support it, ignore it
{
*err_out = sockerrno;
tr_netCloseSocket(fd);
tr_net_close_socket(fd);
return TR_BAD_SOCKET;
}

Expand All @@ -325,7 +319,7 @@ static tr_socket_t tr_netBindTCPImpl(tr_address const& addr, tr_port port, bool
fmt::arg("error_code", err)));
}

tr_netCloseSocket(fd);
tr_net_close_socket(fd);
*err_out = err;
return TR_BAD_SOCKET;
}
Expand Down Expand Up @@ -354,7 +348,7 @@ static tr_socket_t tr_netBindTCPImpl(tr_address const& addr, tr_port port, bool
#endif /* _WIN32 */
{
*err_out = sockerrno;
tr_netCloseSocket(fd);
tr_net_close_socket(fd);
return TR_BAD_SOCKET;
}

Expand Down Expand Up @@ -384,7 +378,7 @@ bool tr_net_hasIPv6(tr_port port)

if (fd != TR_BAD_SOCKET)
{
tr_netCloseSocket(fd);
tr_net_close_socket(fd);
}

already_done = true;
Expand All @@ -410,26 +404,20 @@ std::optional<std::tuple<tr_address, tr_port, tr_socket_t>> tr_netAccept(tr_sess
// make the socket unblocking,
// and confirm we don't have too many peers
auto const addrport = tr_address::from_sockaddr(reinterpret_cast<struct sockaddr*>(&sock));
if (!addrport || evutil_make_socket_nonblocking(sockfd) == -1 || !session->incPeerCount())
if (!addrport || evutil_make_socket_nonblocking(sockfd) == -1 || tr_peer_socket::limit_reached(session))
{
tr_netCloseSocket(sockfd);
tr_net_close_socket(sockfd);
return {};
}

return std::make_tuple(addrport->first, addrport->second, sockfd);
}

void tr_netCloseSocket(tr_socket_t sockfd)
void tr_net_close_socket(tr_socket_t sockfd)
{
evutil_closesocket(sockfd);
}

void tr_netClose(tr_session* session, tr_socket_t sockfd)
{
tr_netCloseSocket(sockfd);
session->decPeerCount();
}

// code in global_ipv6_herlpers is written by Juliusz Chroboczek
// and is covered under the same license as dht.cc.
// Please feel free to copy them into your software if it can help
Expand Down
4 changes: 1 addition & 3 deletions libtransmission/net.h
Original file line number Diff line number Diff line change
Expand Up @@ -322,9 +322,7 @@ tr_socket_t tr_netBindTCP(tr_address const& addr, tr_port port, bool suppress_ms

void tr_netSetCongestionControl(tr_socket_t s, char const* algorithm);

void tr_netClose(tr_session* session, tr_socket_t s);

void tr_netCloseSocket(tr_socket_t fd);
void tr_net_close_socket(tr_socket_t fd);

bool tr_net_hasIPv6(tr_port);

Expand Down
8 changes: 7 additions & 1 deletion libtransmission/peer-io.cc
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,7 @@ std::shared_ptr<tr_peerIo> tr_peerIo::new_outgoing(
bool is_seed,
bool utp)
{
TR_ASSERT(!tr_peer_socket::limit_reached(session));
TR_ASSERT(session != nullptr);
TR_ASSERT(addr.is_valid());
TR_ASSERT(utp || session->allowsTCP());
Expand Down Expand Up @@ -166,7 +167,7 @@ void tr_peerIo::set_socket(tr_peer_socket socket_in)

void tr_peerIo::close()
{
socket_.close(session_);
socket_.close();
event_write_.reset();
event_read_.reset();
}
Expand All @@ -189,6 +190,11 @@ bool tr_peerIo::reconnect()

close();

if (tr_peer_socket::limit_reached(session_))
{
return false;
}

auto const [addr, port] = socket_address();
socket_ = tr_netOpenPeerSocket(session_, addr, port, is_seed());

Expand Down
14 changes: 8 additions & 6 deletions libtransmission/peer-mgr.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1198,11 +1198,11 @@ void tr_peerMgrAddIncoming(tr_peerMgr* manager, tr_peer_socket&& socket)
if (session->addressIsBlocked(socket.address()))
{
tr_logAddTrace(fmt::format("Banned IP address '{}' tried to connect to us", socket.display_name()));
socket.close(session);
socket.close();
}
else if (manager->incoming_handshakes.count(socket.address()) != 0U)
{
socket.close(session);
socket.close();
}
else /* we don't have a connection to them yet... */
{
Expand Down Expand Up @@ -2726,7 +2726,9 @@ void initiateConnection(tr_peerMgr* mgr, tr_swarm* s, peer_atom& atom)
utp = utp && (atom.flags & ADDED_F_UTP_FLAGS) != 0;
}

if (!utp && !mgr->session->allowsTCP())
auto* const session = mgr->session;

if (tr_peer_socket::limit_reached(session) || (!utp && !session->allowsTCP()))
{
return;
}
Expand All @@ -2736,8 +2738,8 @@ void initiateConnection(tr_peerMgr* mgr, tr_swarm* s, peer_atom& atom)
fmt::format("Starting an OUTGOING {} connection with {}", utp ? " µTP" : "TCP", atom.display_name()));

auto peer_io = tr_peerIo::new_outgoing(
mgr->session,
&mgr->session->top_bandwidth_,
session,
&session->top_bandwidth_,
atom.addr,
atom.port,
s->tor->infoHash(),
Expand All @@ -2756,7 +2758,7 @@ void initiateConnection(tr_peerMgr* mgr, tr_swarm* s, peer_atom& atom)
atom.addr,
&mgr->handshake_mediator_,
peer_io,
mgr->session->encryptionMode(),
session->encryptionMode(),
[mgr](tr_handshake::Result const& result) { return on_handshake_done(mgr, result); });
}

Expand Down
14 changes: 12 additions & 2 deletions libtransmission/peer-socket.cc
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ tr_peer_socket::tr_peer_socket(tr_session const* session, tr_address const& addr
{
TR_ASSERT(sock != TR_BAD_SOCKET);

++n_open_sockets_;
session->setSocketTOS(sock, address_.type);

if (auto const& algo = session->peerCongestionAlgorithm(); !std::empty(algo))
Expand All @@ -42,20 +43,24 @@ tr_peer_socket::tr_peer_socket(tr_address const& address, tr_port port, struct U
, type_{ Type::UTP }
{
TR_ASSERT(sock != nullptr);

++n_open_sockets_;
handle.utp = sock;

tr_logAddTraceIo(this, fmt::format("socket (µTP) is {}", fmt::ptr(handle.utp)));
}

void tr_peer_socket::close(tr_session* session)
void tr_peer_socket::close()
{
if (is_tcp() && (handle.tcp != TR_BAD_SOCKET))
{
tr_netClose(session, handle.tcp);
--n_open_sockets_;
tr_net_close_socket(handle.tcp);
}
#ifdef WITH_UTP
else if (is_utp())
{
--n_open_sockets_;
utp_set_userdata(handle.utp, nullptr);
utp_close(handle.utp);
}
Expand Down Expand Up @@ -126,3 +131,8 @@ size_t tr_peer_socket::try_read(Buffer& buf, size_t max, tr_error** error) const

return {};
}

bool tr_peer_socket::limit_reached(tr_session* const session) noexcept
{
return n_open_sockets_.load() >= session->peerLimit();
}
7 changes: 6 additions & 1 deletion libtransmission/peer-socket.h
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
#error only libtransmission should #include this header.
#endif

#include <atomic>
#include <string>
#include <string_view>
#include <utility> // for std::make_pair()
Expand Down Expand Up @@ -37,7 +38,7 @@ class tr_peer_socket
tr_peer_socket& operator=(tr_peer_socket const&) = delete;
~tr_peer_socket() = default;

void close(tr_session* session);
void close();

size_t try_write(Buffer& buf, size_t max, tr_error** error) const;
size_t try_read(Buffer& buf, size_t max, tr_error** error) const;
Expand Down Expand Up @@ -124,6 +125,8 @@ class tr_peer_socket
struct UTPSocket* utp;
} handle = {};

[[nodiscard]] static bool limit_reached(tr_session* const session) noexcept;

private:
enum class Type
{
Expand All @@ -136,6 +139,8 @@ class tr_peer_socket
tr_port port_;

enum Type type_ = Type::None;

static inline std::atomic<size_t> n_open_sockets_ = {};
};

tr_peer_socket tr_netOpenPeerSocket(tr_session* session, tr_address const& addr, tr_port port, bool client_is_seed);
2 changes: 1 addition & 1 deletion libtransmission/session.cc
Original file line number Diff line number Diff line change
Expand Up @@ -331,7 +331,7 @@ tr_session::BoundSocket::~BoundSocket()

if (socket_ != TR_BAD_SOCKET)
{
tr_netCloseSocket(socket_);
tr_net_close_socket(socket_);
socket_ = TR_BAD_SOCKET;
}
}
Expand Down
21 changes: 0 additions & 21 deletions libtransmission/session.h
Original file line number Diff line number Diff line change
Expand Up @@ -496,25 +496,6 @@ struct tr_session
return settings_.peer_limit_per_torrent;
}

[[nodiscard]] constexpr bool incPeerCount() noexcept
{
if (this->peer_count_ >= this->peerLimit())
{
return false;
}

++this->peer_count_;
return true;
}

constexpr void decPeerCount() noexcept
{
if (this->peer_count_ > 0)
{
--this->peer_count_;
}
}

// bandwidth

[[nodiscard]] tr_bandwidth& getBandwidthGroup(std::string_view name);
Expand Down Expand Up @@ -1059,8 +1040,6 @@ struct tr_session
// port than the one requested by Transmission.
tr_port advertised_peer_port_;

uint16_t peer_count_ = 0;

bool is_closing_ = false;

/// fields that aren't trivial,
Expand Down
8 changes: 4 additions & 4 deletions libtransmission/tr-udp.cc
Original file line number Diff line number Diff line change
Expand Up @@ -159,7 +159,7 @@ tr_session::tr_udp_core::tr_udp_core(tr_session& session, tr_port udp_port)
fmt::arg("error", tr_strerror(error_code)),
fmt::arg("error_code", error_code)));

tr_netCloseSocket(sock);
tr_net_close_socket(sock);
}
else
{
Expand Down Expand Up @@ -193,7 +193,7 @@ tr_session::tr_udp_core::tr_udp_core(tr_session& session, tr_port udp_port)
fmt::arg("error", tr_strerror(error_code)),
fmt::arg("error_code", error_code)));

tr_netCloseSocket(sock);
tr_net_close_socket(sock);
}
else
{
Expand All @@ -220,15 +220,15 @@ tr_session::tr_udp_core::~tr_udp_core()

if (udp6_socket_ != TR_BAD_SOCKET)
{
tr_netCloseSocket(udp6_socket_);
tr_net_close_socket(udp6_socket_);
udp6_socket_ = TR_BAD_SOCKET;
}

udp4_event_.reset();

if (udp4_socket_ != TR_BAD_SOCKET)
{
tr_netCloseSocket(udp4_socket_);
tr_net_close_socket(udp4_socket_);
udp4_socket_ = TR_BAD_SOCKET;
}
}
Expand Down
2 changes: 1 addition & 1 deletion libtransmission/tr-utp.cc
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,7 @@ static void utp_on_accept(tr_session* const session, UTPSocket* const utp_sock)
auto* const from = (struct sockaddr*)&from_storage;
socklen_t fromlen = sizeof(from_storage);

if (!session->allowsUTP())
if (!session->allowsUTP() || tr_peer_socket::limit_reached(session))
{
utp_close(utp_sock);
return;
Expand Down

0 comments on commit b47c347

Please sign in to comment.