Skip to content

Commit

Permalink
Fix valence_network timeouts + other fixes (#421)
Browse files Browse the repository at this point in the history
# Objective

Make timeouts in `valence_network` per connection rather than per
packet.

# Solution

- Remove timeouts from `PacketIo` and `try_handle_legacy_ping` call.
- Add timeout to `handle_connection` call.
- More consistent variable names.
- Add missing disconnect message when protocol version mismatch occurs.
  • Loading branch information
rj00a authored Jul 24, 2023
1 parent bfaa709 commit 4d5796a
Show file tree
Hide file tree
Showing 2 changed files with 57 additions and 59 deletions.
71 changes: 40 additions & 31 deletions crates/valence_network/src/connect.rs
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,8 @@ use valence_core::protocol::encode::PacketEncoder;
use valence_core::protocol::raw::RawBytes;
use valence_core::protocol::var_int::VarInt;
use valence_core::protocol::Decode;
use valence_core::text::{Color, Text};
use valence_core::{ident, translation_key, PROTOCOL_VERSION};
use valence_core::text::{Color, IntoText, Text};
use valence_core::{ident, translation_key, MINECRAFT_VERSION, PROTOCOL_VERSION};

use crate::legacy_ping::try_handle_legacy_ping;
use crate::packet::{
Expand All @@ -46,14 +46,24 @@ pub(super) async fn do_accept_loop(shared: SharedNetworkState) {
}
};

let timeout = Duration::from_secs(5);

loop {
match shared.0.connection_sema.clone().acquire_owned().await {
Ok(permit) => match listener.accept().await {
Ok((stream, remote_addr)) => {
let shared = shared.clone();

tokio::spawn(async move {
handle_connection(shared, stream, remote_addr).await;
if let Err(e) = tokio::time::timeout(
timeout,
handle_connection(shared, stream, remote_addr),
)
.await
{
warn!("initial connection timed out: {e}");
}

drop(permit);
});
}
Expand All @@ -78,26 +88,18 @@ async fn handle_connection(
error!("failed to set TCP_NODELAY: {e}");
}

let timeout = Duration::from_secs(5);

match tokio::time::timeout(
timeout,
try_handle_legacy_ping(&shared, &mut stream, remote_addr),
)
.await
.unwrap_or(Err(io::Error::new(io::ErrorKind::TimedOut, "timed out")))
{
Ok(true) => return,
Ok(false) => {}
match try_handle_legacy_ping(&shared, &mut stream, remote_addr).await {
Ok(true) => return, // Legacy ping succeeded.
Ok(false) => {} // No legacy ping.
Err(e) if e.kind() == io::ErrorKind::UnexpectedEof => {}
Err(e) => {
warn!("connection ended with error: {e:#}");
warn!("legacy ping ended with error: {e:#}");
}
}

let conn = PacketIo::new(stream, PacketEncoder::new(), PacketDecoder::new(), timeout);
let io = PacketIo::new(stream, PacketEncoder::new(), PacketDecoder::new());

if let Err(e) = handle_handshake(shared, conn, remote_addr).await {
if let Err(e) = handle_handshake(shared, io, remote_addr).await {
// EOF can happen if the client disconnects while joining, which isn't
// very erroneous.
if let Some(e) = e.downcast_ref::<io::Error>() {
Expand Down Expand Up @@ -246,53 +248,60 @@ async fn handle_status(
/// Handle the login process and return the new client's data if successful.
async fn handle_login(
shared: &SharedNetworkState,
conn: &mut PacketIo,
io: &mut PacketIo,
remote_addr: SocketAddr,
handshake: HandshakeData,
) -> anyhow::Result<Option<(NewClientInfo, CleanupOnDrop)>> {
if handshake.protocol_version != PROTOCOL_VERSION {
// TODO: send translated disconnect msg.
io.send_packet(&LoginDisconnectS2c {
// TODO: use correct translation key.
reason: format!("Mismatched Minecraft version (server is on {MINECRAFT_VERSION})")
.color(Color::RED)
.into(),
})
.await?;

return Ok(None);
}

let LoginHelloC2s {
username,
profile_id: _, // TODO
} = conn.recv_packet().await?;
} = io.recv_packet().await?;

let username = username.to_owned();

let info = match shared.connection_mode() {
ConnectionMode::Online { .. } => login_online(shared, conn, remote_addr, username).await?,
ConnectionMode::Online { .. } => login_online(shared, io, remote_addr, username).await?,
ConnectionMode::Offline => login_offline(remote_addr, username)?,
ConnectionMode::BungeeCord => {
login_bungeecord(remote_addr, &handshake.server_address, username)?
}
ConnectionMode::Velocity { secret } => login_velocity(conn, username, secret).await?,
ConnectionMode::Velocity { secret } => login_velocity(io, username, secret).await?,
};

if let Some(threshold) = shared.0.compression_threshold {
conn.send_packet(&LoginCompressionS2c {
io.send_packet(&LoginCompressionS2c {
threshold: VarInt(threshold as i32),
})
.await?;

conn.set_compression(Some(threshold));
io.set_compression(Some(threshold));
}

let cleanup = match shared.0.callbacks.inner.login(shared, &info).await {
Ok(f) => CleanupOnDrop(Some(f)),
Err(reason) => {
info!("disconnect at login: \"{reason}\"");
conn.send_packet(&LoginDisconnectS2c {
io.send_packet(&LoginDisconnectS2c {
reason: reason.into(),
})
.await?;
return Ok(None);
}
};

conn.send_packet(&LoginSuccessS2c {
io.send_packet(&LoginSuccessS2c {
uuid: info.uuid,
username: &info.username,
properties: Default::default(),
Expand All @@ -305,13 +314,13 @@ async fn handle_login(
/// Login procedure for online mode.
async fn login_online(
shared: &SharedNetworkState,
conn: &mut PacketIo,
io: &mut PacketIo,
remote_addr: SocketAddr,
username: String,
) -> anyhow::Result<NewClientInfo> {
let my_verify_token: [u8; 16] = rand::random();

conn.send_packet(&LoginHelloS2c {
io.send_packet(&LoginHelloS2c {
server_id: "", // Always empty
public_key: &shared.0.public_key_der,
verify_token: &my_verify_token,
Expand All @@ -321,7 +330,7 @@ async fn login_online(
let LoginKeyC2s {
shared_secret,
verify_token: encrypted_verify_token,
} = conn.recv_packet().await?;
} = io.recv_packet().await?;

let shared_secret = shared
.0
Expand All @@ -345,7 +354,7 @@ async fn login_online(
.try_into()
.context("shared secret has the wrong length")?;

conn.enable_encryption(&crypt_key);
io.enable_encryption(&crypt_key);

let hash = Sha1::new()
.chain(&shared_secret)
Expand Down Expand Up @@ -373,7 +382,7 @@ async fn login_online(
translation_key::MULTIPLAYER_DISCONNECT_UNVERIFIED_USERNAME,
[],
);
conn.send_packet(&LoginDisconnectS2c {
io.send_packet(&LoginDisconnectS2c {
reason: reason.into(),
})
.await?;
Expand Down
45 changes: 17 additions & 28 deletions crates/valence_network/src/packet_io.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
use std::io::ErrorKind;
use std::sync::Arc;
use std::time::{Duration, Instant};
use std::time::Instant;
use std::{io, mem};

use anyhow::bail;
Expand All @@ -9,7 +9,6 @@ use tokio::io::{AsyncReadExt, AsyncWriteExt};
use tokio::net::TcpStream;
use tokio::sync::Semaphore;
use tokio::task::JoinHandle;
use tokio::time::timeout;
use tracing::{debug, warn};
use valence_client::{ClientBundleArgs, ClientConnection, ReceivedPacket};
use valence_core::protocol::decode::{PacketDecoder, PacketFrame};
Expand All @@ -24,18 +23,12 @@ pub(crate) struct PacketIo {
enc: PacketEncoder,
dec: PacketDecoder,
frame: PacketFrame,
timeout: Duration,
}

const READ_BUF_SIZE: usize = 4096;

impl PacketIo {
pub(crate) fn new(
stream: TcpStream,
enc: PacketEncoder,
dec: PacketDecoder,
timeout: Duration,
) -> Self {
pub(crate) fn new(stream: TcpStream, enc: PacketEncoder, dec: PacketDecoder) -> Self {
Self {
stream,
enc,
Expand All @@ -44,7 +37,6 @@ impl PacketIo {
id: -1,
body: BytesMut::new(),
},
timeout,
}
}

Expand All @@ -54,35 +46,32 @@ impl PacketIo {
{
self.enc.append_packet(pkt)?;
let bytes = self.enc.take();
timeout(self.timeout, self.stream.write_all(&bytes)).await??;
self.stream.write_all(&bytes).await?;
Ok(())
}

pub(crate) async fn recv_packet<'a, P>(&'a mut self) -> anyhow::Result<P>
where
P: Packet + Decode<'a>,
{
timeout(self.timeout, async {
loop {
if let Some(frame) = self.dec.try_next_packet()? {
self.frame = frame;
loop {
if let Some(frame) = self.dec.try_next_packet()? {
self.frame = frame;

return self.frame.decode();
}
return self.frame.decode();
}

self.dec.reserve(READ_BUF_SIZE);
let mut buf = self.dec.take_capacity();
self.dec.reserve(READ_BUF_SIZE);
let mut buf = self.dec.take_capacity();

if self.stream.read_buf(&mut buf).await? == 0 {
return Err(io::Error::from(ErrorKind::UnexpectedEof).into());
}

// This should always be an O(1) unsplit because we reserved space earlier and
// the call to `read_buf` shouldn't have grown the allocation.
self.dec.queue_bytes(buf);
if self.stream.read_buf(&mut buf).await? == 0 {
return Err(io::Error::from(ErrorKind::UnexpectedEof).into());
}
})
.await?

// This should always be an O(1) unsplit because we reserved space earlier and
// the call to `read_buf` shouldn't have grown the allocation.
self.dec.queue_bytes(buf);
}
}

#[allow(dead_code)]
Expand Down

0 comments on commit 4d5796a

Please sign in to comment.