Skip to content

Commit

Permalink
Merge pull request #85 from nyonson/switch-to-mutable-io-refs
Browse files Browse the repository at this point in the history
Switch to mutable refs for async io
  • Loading branch information
rustaceanrob authored Oct 30, 2024
2 parents d37c1a3 + 98fe548 commit df1f835
Show file tree
Hide file tree
Showing 2 changed files with 51 additions and 55 deletions.
92 changes: 44 additions & 48 deletions protocol/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1070,21 +1070,13 @@ impl fmt::Display for ProtocolError {

/// A protocol session with handshake and send/receive packet management.
#[cfg(any(feature = "async", feature = "tokio"))]
pub struct AsyncProtocol<R, W>
where
R: AsyncRead + Unpin + Send,
W: AsyncWrite + Unpin + Send,
{
reader: AsyncProtocolReader<R>,
writer: AsyncProtocolWriter<W>,
pub struct AsyncProtocol {
reader: AsyncProtocolReader,
writer: AsyncProtocolWriter,
}

#[cfg(any(feature = "async", feature = "tokio"))]
impl<R, W> AsyncProtocol<R, W>
where
R: AsyncRead + Unpin + Send,
W: AsyncWrite + Unpin + Send,
{
impl AsyncProtocol {
/// New protocol session which completes the initial handshake and returns a handler.
///
/// # Arguments
Expand All @@ -1105,14 +1097,18 @@ where
/// # Errors
///
/// * `Io` - Includes a flag for if the remote probably only understands the V1 protocol.
pub async fn new<'a>(
pub async fn new<'a, R, W>(
network: Network,
role: Role,
garbage: Option<&'a [u8]>,
decoys: Option<&'a [&'a [u8]]>,
mut reader: R,
mut writer: W,
) -> Result<Self, ProtocolError> {
reader: &mut R,
writer: &mut W,
) -> Result<Self, ProtocolError>
where
R: AsyncRead + Unpin + Send,
W: AsyncWrite + Unpin + Send,
{
let garbage_len = match garbage {
Some(slice) => slice.len(),
None => 0,
Expand Down Expand Up @@ -1190,29 +1186,25 @@ where

Ok(Self {
reader: AsyncProtocolReader {
buffer: reader,
packet_reader,
state: DecryptState::default(),
},
writer: AsyncProtocolWriter {
buffer: writer,
packet_writer,
},
writer: AsyncProtocolWriter { packet_writer },
})
}

/// Read reference for packet reading operations.
pub fn reader(&mut self) -> &mut AsyncProtocolReader<R> {
pub fn reader(&mut self) -> &mut AsyncProtocolReader {
&mut self.reader
}

/// Write reference for packet writing operations.
pub fn writer(&mut self) -> &mut AsyncProtocolWriter<W> {
pub fn writer(&mut self) -> &mut AsyncProtocolWriter {
&mut self.writer
}

/// Split the protocol into a separate reader and writer.
pub fn into_split(self) -> (AsyncProtocolReader<R>, AsyncProtocolWriter<W>) {
pub fn into_split(self) -> (AsyncProtocolReader, AsyncProtocolWriter) {
(self.reader, self.writer)
}
}
Expand Down Expand Up @@ -1243,30 +1235,30 @@ impl Default for DecryptState {

/// Manages an async buffer to automatically decrypt contents of received packets.
#[cfg(any(feature = "async", feature = "tokio"))]
pub struct AsyncProtocolReader<R>
where
R: AsyncRead + Unpin + Send,
{
buffer: R,
pub struct AsyncProtocolReader {
packet_reader: PacketReader,
state: DecryptState,
}

#[cfg(any(feature = "async", feature = "tokio"))]
impl<R> AsyncProtocolReader<R>
where
R: AsyncRead + Unpin + Send,
{
impl AsyncProtocolReader {
/// Decrypt contents of received packet from buffer.
///
/// This function is cancellation safe.
///
/// # Arguments
///
/// * `buffer` - Asynchronous I/O buffer to pull bytes from.
///
/// # Returns
///
/// A `Result` containing:
/// * `Ok(Payload)`: A decrypted payload.
/// * `Err(ProtocolError)`: An error that occurred during the read or decryption.
pub async fn decrypt(&mut self) -> Result<Payload, ProtocolError> {
pub async fn read_and_decrypt<R>(&mut self, buffer: &mut R) -> Result<Payload, ProtocolError>
where
R: AsyncRead + Unpin + Send,
{
// Storing state between async reads to make function cancellation safe.
loop {
match &mut self.state {
Expand All @@ -1275,7 +1267,7 @@ where
bytes_read,
} => {
while *bytes_read < 3 {
*bytes_read += self.buffer.read(&mut length_bytes[*bytes_read..]).await?;
*bytes_read += buffer.read(&mut length_bytes[*bytes_read..]).await?;
}

let packet_bytes_len = self.packet_reader.decypt_len(*length_bytes);
Expand All @@ -1290,7 +1282,7 @@ where
bytes_read,
} => {
while *bytes_read < packet_bytes.len() {
*bytes_read += self.buffer.read(&mut packet_bytes[*bytes_read..]).await?;
*bytes_read += buffer.read(&mut packet_bytes[*bytes_read..]).await?;
}

let payload = self.packet_reader.decrypt_payload(packet_bytes, None)?;
Expand All @@ -1304,32 +1296,36 @@ where

/// Manages an async buffer to automatically encrypt and send contents in packets.
#[cfg(any(feature = "async", feature = "tokio"))]
pub struct AsyncProtocolWriter<W>
where
W: AsyncWrite + Unpin + Send,
{
buffer: W,
pub struct AsyncProtocolWriter {
packet_writer: PacketWriter,
}

#[cfg(any(feature = "async", feature = "tokio"))]
impl<W> AsyncProtocolWriter<W>
where
W: AsyncWrite + Unpin + Send,
{
impl AsyncProtocolWriter {
/// Encrypt contents and write packet buffer.
///
/// # Arguments
///
/// * `buffer` - Asynchronous I/O buffer to write bytes to.
///
/// # Returns
///
/// A `Result` containing:
/// * `Ok()`: On successful contents encryption and packet send.
/// * `Err(ProtocolError)`: An error that occurred during the encryption or write.
pub async fn encrypt(&mut self, plaintext: &[u8]) -> Result<(), ProtocolError> {
pub async fn encrypt_and_write<W>(
&mut self,
plaintext: &[u8],
buffer: &mut W,
) -> Result<(), ProtocolError>
where
W: AsyncWrite + Unpin + Send,
{
let write_bytes =
self.packet_writer
.encrypt_packet(plaintext, None, PacketType::Genuine)?;
self.buffer.write_all(&write_bytes[..]).await?;
self.buffer.flush().await?;
buffer.write_all(&write_bytes[..]).await?;
buffer.flush().await?;
Ok(())
}
}
Expand Down
14 changes: 7 additions & 7 deletions proxy/src/bin/proxy.rs
Original file line number Diff line number Diff line change
Expand Up @@ -71,15 +71,15 @@ async fn v2_proxy(
.expect("connect to remote");

info!("Initiating handshake.");
let (remote_reader, remote_writer) = remote.into_split();
let (mut remote_reader, mut remote_writer) = remote.into_split();

let protocol = match AsyncProtocol::new(
network,
Role::Initiator,
None,
None,
remote_reader,
remote_writer,
&mut remote_reader,
&mut remote_writer,
)
.await
{
Expand All @@ -95,7 +95,7 @@ async fn v2_proxy(
let mut v1_client_reader = V1ProtocolReader::new(client_reader);
let mut v1_client_writer = V1ProtocolWriter::new(network, client_writer);

let (mut remote_reader, mut remote_writer) = protocol.into_split();
let (mut v2_remote_reader, mut v2_remote_writer) = protocol.into_split();

info!("Setting up V2 proxy.");

Expand All @@ -109,12 +109,12 @@ async fn v2_proxy(
);

let contents = serialize(msg).expect("serialize-able contents into network message");
remote_writer
.encrypt(&contents)
v2_remote_writer
.encrypt_and_write(&contents, &mut remote_writer)
.await
.expect("write to remote");
},
result = remote_reader.decrypt() => {
result = v2_remote_reader.read_and_decrypt(&mut remote_reader) => {
let payload = result.expect("read packet");
// Ignore decoy packets.
if payload.packet_type() == PacketType::Genuine {
Expand Down

0 comments on commit df1f835

Please sign in to comment.