From 5f1d0e5a76a0e27b3484a4b7585d1d549d830aaf Mon Sep 17 00:00:00 2001 From: Dario Date: Mon, 27 May 2024 22:46:03 +0200 Subject: [PATCH 01/23] frame: Use abstract trait to write the header to --- src/frame.rs | 25 +++++++++++++------------ 1 file changed, 13 insertions(+), 12 deletions(-) diff --git a/src/frame.rs b/src/frame.rs index 9ac133d..0a34ce6 100644 --- a/src/frame.rs +++ b/src/frame.rs @@ -14,7 +14,7 @@ use tokio::io::AsyncWriteExt; -use bytes::BytesMut; +use bytes::{BufMut, BytesMut}; use core::ops::Deref; use crate::WebSocketError; @@ -259,26 +259,27 @@ impl<'f> Frame<'f> { /// # Panics /// /// This method panics if the head buffer is not at least n-bytes long, where n is the size of the length field (0, 2, 4, or 10) - pub fn fmt_head(&mut self, head: &mut [u8]) -> usize { - head[0] = (self.fin as u8) << 7 | (self.opcode as u8); + pub fn fmt_head(&mut self, mut head: impl BufMut) -> usize { + head.put_u8((self.fin as u8) << 7 | (self.opcode as u8)); + + let mask_bit = if self.mask.is_some() { 0x80 } else { 0x0 }; let len = self.payload.len(); let size = if len < 126 { - head[1] = len as u8; + head.put_u8(len as u8 | mask_bit); 2 } else if len < 65536 { - head[1] = 126; - head[2..4].copy_from_slice(&(len as u16).to_be_bytes()); + head.put_u8(126u8 | mask_bit); + head.put_slice(&(len as u16).to_be_bytes()); 4 } else { - head[1] = 127; - head[2..10].copy_from_slice(&(len as u64).to_be_bytes()); + head.put_u8(127u8 | mask_bit); + head.put_slice(&(len as u64).to_be_bytes()); 10 }; if let Some(mask) = self.mask { - head[1] |= 0x80; - head[size..size + 4].copy_from_slice(&mask); + head.put_slice(&mask); size + 4 } else { size @@ -295,7 +296,7 @@ impl<'f> Frame<'f> { use std::io::IoSlice; let mut head = [0; MAX_HEAD_SIZE]; - let size = self.fmt_head(&mut head); + let size = self.fmt_head(&mut head[..]); let total = size + self.payload.len(); @@ -330,7 +331,7 @@ impl<'f> Frame<'f> { let len = self.payload.len(); reserve_enough(buf, len + MAX_HEAD_SIZE); - let size = self.fmt_head(buf); + let size = self.fmt_head(&mut *buf); buf[size..size + len].copy_from_slice(&self.payload); &buf[..size + len] } From 9ea19e614f76a9ed5ca500771afe054a84e44fd3 Mon Sep 17 00:00:00 2001 From: Dario Date: Mon, 27 May 2024 22:46:50 +0200 Subject: [PATCH 02/23] Implement poll* methods to support future-based systems --- Cargo.lock | 33 ++++++++- Cargo.toml | 20 ++++-- src/error.rs | 1 - src/fragment.rs | 4 -- src/lib.rs | 183 ++++++++++++++++++++++++++++++++---------------- 5 files changed, 170 insertions(+), 71 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 93add82..db07e73 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -400,6 +400,7 @@ dependencies = [ "base64", "bytes", "criterion", + "futures", "http", "http-body-util", "hyper", @@ -412,6 +413,7 @@ dependencies = [ "thiserror", "tokio", "tokio-rustls", + "tokio-util", "trybuild", "utf-8", "webpki-roots", @@ -432,6 +434,20 @@ dependencies = [ "percent-encoding", ] +[[package]] +name = "futures" +version = "0.3.30" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "645c6916888f6cb6350d2550b80fb63e734897a8498abe35cfb732b6487804b0" +dependencies = [ + "futures-channel", + "futures-core", + "futures-io", + "futures-sink", + "futures-task", + "futures-util", +] + [[package]] name = "futures-channel" version = "0.3.30" @@ -439,6 +455,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "eac8f7d7865dcb88bd4373ab671c8cf4508703796caa2b1985a9ca867b3fcb78" dependencies = [ "futures-core", + "futures-sink", ] [[package]] @@ -447,6 +464,12 @@ version = "0.3.30" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "dfc6580bb841c5a68e9ef15c77ccc837b40a7504914d52e47b8b0e9bbda25a1d" +[[package]] +name = "futures-io" +version = "0.3.30" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a44623e20b9681a318efdd71c299b6b222ed6f231972bfe2f224ebad6311f0c1" + [[package]] name = "futures-sink" version = "0.3.30" @@ -465,10 +488,15 @@ version = "0.3.30" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "3d6401deb83407ab3da39eba7e33987a73c3df0c82b4bb5813ee871c19c41d48" dependencies = [ + "futures-channel", "futures-core", + "futures-io", + "futures-sink", "futures-task", + "memchr", "pin-project-lite", "pin-utils", + "slab", ] [[package]] @@ -1386,16 +1414,15 @@ dependencies = [ [[package]] name = "tokio-util" -version = "0.7.10" +version = "0.7.11" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5419f34732d9eb6ee4c3578b7989078579b7f039cbbb9ca2c4da015749371e15" +checksum = "9cf6b47b3771c49ac75ad09a6162f53ad4b8088b76ac60e8ec1455b31a189fe1" dependencies = [ "bytes", "futures-core", "futures-sink", "pin-project-lite", "tokio", - "tracing", ] [[package]] diff --git a/Cargo.toml b/Cargo.toml index 5622f3f..ea64e96 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -28,11 +28,15 @@ path = "examples/axum.rs" required-features = ["upgrade", "with_axum"] [dependencies] -tokio = { version = "1.25.0", default-features = false, features = ["io-util"] } +tokio = { version = "1.25.0", default-features = false, features = ["io-util"] } simdutf8 = { version = "0.1.4", optional = true } hyper-util = { version = "0.1.0", features = ["tokio"], optional = true } http-body-util = { version = "0.1.0", optional = true } -hyper = { version = "1", features = ["http1", "server", "client"], optional = true } +hyper = { version = "1", features = [ + "http1", + "server", + "client", +], optional = true } pin-project = { version = "1.0.8", optional = true } base64 = { version = "0.21.0", optional = true } sha1 = { version = "0.10.5", optional = true } @@ -45,12 +49,20 @@ bytes = "1.5.0" axum-core = { version = "0.4.3", optional = true } http = { version = "1", optional = true } async-trait = { version = "0.1", optional = true } +tokio-util = { version = "0.7.11", features = ["codec", "io"] } +futures = { version = "0.3.30", default-features = false, features = ["std"] } [features] default = ["simd"] simd = ["simdutf8/aarch64_neon"] -upgrade = ["hyper", "pin-project", "base64", "sha1", "hyper-util", "http-body-util"] -unstable-split = [] +upgrade = [ + "hyper", + "pin-project", + "base64", + "sha1", + "hyper-util", + "http-body-util", +] # Axum integration with_axum = ["axum-core", "http", "async-trait"] diff --git a/src/error.rs b/src/error.rs index 848116a..b8e08d5 100644 --- a/src/error.rs +++ b/src/error.rs @@ -41,7 +41,6 @@ pub enum WebSocketError { #[cfg(feature = "upgrade")] #[error(transparent)] HTTPError(#[from] hyper::Error), - #[cfg(feature = "unstable-split")] #[error("Failed to send frame")] SendError(#[from] Box), } diff --git a/src/fragment.rs b/src/fragment.rs index 091fb03..9a1932a 100644 --- a/src/fragment.rs +++ b/src/fragment.rs @@ -12,7 +12,6 @@ // See the License for the specific language governing permissions and // limitations under the License. -#[cfg(feature = "unstable-split")] use std::future::Future; use crate::error::WebSocketError; @@ -20,7 +19,6 @@ use crate::frame::Frame; use crate::OpCode; use crate::ReadHalf; use crate::WebSocket; -#[cfg(feature = "unstable-split")] use crate::WebSocketRead; use crate::WriteHalf; use tokio::io::AsyncRead; @@ -137,14 +135,12 @@ impl<'f, S> FragmentCollector { } } -#[cfg(feature = "unstable-split")] pub struct FragmentCollectorRead { stream: S, read_half: ReadHalf, fragments: Fragments, } -#[cfg(feature = "unstable-split")] impl<'f, S> FragmentCollectorRead { /// Creates a new `FragmentCollector` with the provided `WebSocket`. pub fn new(ws: WebSocketRead) -> FragmentCollectorRead diff --git a/src/lib.rs b/src/lib.rs index b6d041c..b48741a 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -167,18 +167,19 @@ pub mod upgrade; use bytes::Buf; use bytes::BytesMut; -#[cfg(feature = "unstable-split")] +use std::future::poll_fn; use std::future::Future; +use std::pin::pin; +use std::task::ready; +use std::task::Context; +use std::task::Poll; use tokio::io::AsyncRead; -use tokio::io::AsyncReadExt; use tokio::io::AsyncWrite; -use tokio::io::AsyncWriteExt; pub use crate::close::CloseCode; pub use crate::error::WebSocketError; pub use crate::fragment::FragmentCollector; -#[cfg(feature = "unstable-split")] pub use crate::fragment::FragmentCollectorRead; pub use crate::frame::Frame; pub use crate::frame::OpCode; @@ -197,7 +198,7 @@ pub(crate) struct WriteHalf { vectored: bool, auto_apply_mask: bool, writev_threshold: usize, - write_buffer: Vec, + buffer: BytesMut, } pub(crate) struct ReadHalf { @@ -210,19 +211,16 @@ pub(crate) struct ReadHalf { buffer: BytesMut, } -#[cfg(feature = "unstable-split")] pub struct WebSocketRead { stream: S, read_half: ReadHalf, } -#[cfg(feature = "unstable-split")] pub struct WebSocketWrite { stream: S, write_half: WriteHalf, } -#[cfg(feature = "unstable-split")] /// Create a split `WebSocketRead`/`WebSocketWrite` pair from a stream that has already completed the WebSocket handshake. pub fn after_handshake_split( read: R, @@ -245,7 +243,6 @@ where ) } -#[cfg(feature = "unstable-split")] impl<'f, S> WebSocketRead { /// Consumes the `WebSocketRead` and returns the underlying stream. #[inline] @@ -309,7 +306,6 @@ impl<'f, S> WebSocketRead { } } -#[cfg(feature = "unstable-split")] impl<'f, S> WebSocketWrite { /// Sets whether to use vectored writes. This option does not guarantee that vectored writes will be always used. /// @@ -385,7 +381,6 @@ impl<'f, S> WebSocket { /// Split a [`WebSocket`] into a [`WebSocketRead`] and [`WebSocketWrite`] half. Note that the split version does not /// handle fragmented packets and you may wish to create a [`FragmentCollectorRead`] over top of the read half that /// is returned. - #[cfg(feature = "unstable-split")] pub fn split( self, split_fn: impl Fn(S) -> (R, W), @@ -522,30 +517,42 @@ impl<'f, S> WebSocket { /// } /// ``` pub async fn read_frame(&mut self) -> Result, WebSocketError> + where + S: AsyncRead + AsyncWrite + Unpin, + { + poll_fn(|cx| self.poll_read_frame(cx)).await + } + + pub fn poll_read_frame( + &mut self, + cx: &mut Context<'_>, + ) -> Poll, WebSocketError>> where S: AsyncRead + AsyncWrite + Unpin, { loop { let (res, obligated_send) = - self.read_half.read_frame_inner(&mut self.stream).await; + ready!(self.read_half.poll_read_frame_inner(&mut self.stream, cx)); + let is_closed = self.write_half.closed; if let Some(frame) = obligated_send { if !is_closed { - self.write_half.write_frame(&mut self.stream, frame).await?; + self.write_half.start_send_frame(frame)?; + ready!(self.write_half.poll_flush(&mut self.stream, cx))?; } } + if let Some(frame) = res? { if is_closed && frame.opcode != OpCode::Close { - return Err(WebSocketError::ConnectionClosed); + return Poll::Ready(Err(WebSocketError::ConnectionClosed)); } - break Ok(frame); + + break Poll::Ready(Ok(frame)); } } } } -const MAX_HEADER_SIZE: usize = 14; - impl ReadHalf { pub fn after_handshake(role: Role) -> Self { let buffer = BytesMut::with_capacity(8192); @@ -574,9 +581,20 @@ impl ReadHalf { where S: AsyncRead + Unpin, { - let mut frame = match self.parse_frame_header(stream).await { + poll_fn(|cx| self.poll_read_frame_inner(stream, cx)).await + } + + pub(crate) fn poll_read_frame_inner<'f, S>( + &mut self, + stream: &mut S, + cx: &mut Context<'_>, + ) -> Poll<(Result>, WebSocketError>, Option>)> + where + S: AsyncRead + Unpin, + { + let mut frame = match ready!(self.poll_parse_frame_header(stream, cx)) { Ok(frame) => frame, - Err(e) => return (Err(e), None), + Err(e) => return Poll::Ready((Err(e), None)), }; if self.role == Role::Server && self.auto_apply_mask { @@ -587,7 +605,9 @@ impl ReadHalf { OpCode::Close if self.auto_close => { match frame.payload.len() { 0 => {} - 1 => return (Err(WebSocketError::InvalidCloseFrame), None), + 1 => { + return Poll::Ready((Err(WebSocketError::InvalidCloseFrame), None)) + } _ => { let code = close::CloseCode::from(u16::from_be_bytes( frame.payload[0..2].try_into().unwrap(), @@ -595,7 +615,7 @@ impl ReadHalf { #[cfg(feature = "simd")] if simdutf8::basic::from_utf8(&frame.payload[2..]).is_err() { - return (Err(WebSocketError::InvalidUTF8), None); + return Poll::Ready((Err(WebSocketError::InvalidUTF8), None)); }; #[cfg(not(feature = "simd"))] @@ -604,49 +624,55 @@ impl ReadHalf { }; if !code.is_allowed() { - return ( + return Poll::Ready(( Err(WebSocketError::InvalidCloseCode), Some(Frame::close(1002, &frame.payload[2..])), - ); + )); } } }; let obligated_send = Frame::close_raw(frame.payload.to_owned().into()); - (Ok(Some(frame)), Some(obligated_send)) + Poll::Ready((Ok(Some(frame)), Some(obligated_send))) } OpCode::Ping if self.auto_pong => { - (Ok(None), Some(Frame::pong(frame.payload))) + Poll::Ready((Ok(None), Some(Frame::pong(frame.payload)))) } OpCode::Text => { if frame.fin && !frame.is_utf8() { - (Err(WebSocketError::InvalidUTF8), None) + Poll::Ready((Err(WebSocketError::InvalidUTF8), None)) } else { - (Ok(Some(frame)), None) + Poll::Ready((Ok(Some(frame)), None)) } } - _ => (Ok(Some(frame)), None), + _ => Poll::Ready((Ok(Some(frame)), None)), } } - async fn parse_frame_header<'a, S>( + fn poll_parse_frame_header<'a, S>( &mut self, stream: &mut S, - ) -> Result, WebSocketError> + cx: &mut Context<'_>, + ) -> Poll, WebSocketError>> where S: AsyncRead + Unpin, { - macro_rules! eof { - ($n:expr) => {{ - if $n == 0 { - return Err(WebSocketError::UnexpectedEOF); + macro_rules! read_next { + () => {{ + let bytes_read = ready!(tokio_util::io::poll_read_buf( + pin!(&mut *stream), + cx, + &mut self.buffer + ))?; + if bytes_read == 0 { + return Poll::Ready(Err(WebSocketError::UnexpectedEOF)); } }}; } // Read the first two bytes while self.buffer.remaining() < 2 { - eof!(stream.read_buf(&mut self.buffer).await?); + read_next!(); } let fin = self.buffer[0] & 0b10000000 != 0; @@ -655,7 +681,7 @@ impl ReadHalf { let rsv3 = self.buffer[0] & 0b00010000 != 0; if rsv1 || rsv2 || rsv3 { - return Err(WebSocketError::ReservedBitsNotZero); + return Poll::Ready(Err(WebSocketError::ReservedBitsNotZero)); } let opcode = frame::OpCode::try_from(self.buffer[0] & 0b00001111)?; @@ -668,23 +694,26 @@ impl ReadHalf { _ => 0, }; - self.buffer.advance(2); - while self.buffer.remaining() < extra + masked as usize * 4 { - eof!(stream.read_buf(&mut self.buffer).await?); + // total header size + let header_size = 2 + extra + masked as usize * 4; + while self.buffer.remaining() < header_size { + read_next!(); } + let mut header = &self.buffer[2..header_size]; + let payload_len: usize = match extra { 0 => usize::from(length_code), - 2 => self.buffer.get_u16() as usize, + 2 => header.get_u16() as usize, #[cfg(any(target_pointer_width = "64", target_pointer_width = "128"))] - 8 => self.buffer.get_u64() as usize, + 8 => header.get_u64() as usize, // On 32bit systems, usize is only 4bytes wide so we must check for usize overflowing #[cfg(any( target_pointer_width = "8", target_pointer_width = "16", target_pointer_width = "32" ))] - 8 => match usize::try_from(self.buffer.get_u64()) { + 8 => match usize::try_from(header.get_u64()) { Ok(length) => length, Err(_) => return Err(WebSocketError::FrameTooLarge), }, @@ -692,33 +721,33 @@ impl ReadHalf { }; let mask = if masked { - Some(self.buffer.get_u32().to_be_bytes()) + Some(header.get_u32().to_be_bytes()) } else { None }; if frame::is_control(opcode) && !fin { - return Err(WebSocketError::ControlFrameFragmented); + return Poll::Ready(Err(WebSocketError::ControlFrameFragmented)); } if opcode == OpCode::Ping && payload_len > 125 { - return Err(WebSocketError::PingFrameTooLarge); + return Poll::Ready(Err(WebSocketError::PingFrameTooLarge)); } if payload_len >= self.max_message_size { - return Err(WebSocketError::FrameTooLarge); + return Poll::Ready(Err(WebSocketError::FrameTooLarge)); } // Reserve a bit more to try to get next frame header and avoid a syscall to read it next time - self.buffer.reserve(payload_len + MAX_HEADER_SIZE); - while payload_len > self.buffer.remaining() { - eof!(stream.read_buf(&mut self.buffer).await?); + while header_size + payload_len > self.buffer.remaining() { + read_next!(); } // if we read too much it will stay in the buffer, for the next call to this method - let payload = self.buffer.split_to(payload_len); + let mut message = self.buffer.split_to(payload_len + header_size); + let payload = message.split_off(header_size); let frame = Frame::new(fin, opcode, mask, Payload::Bytes(payload)); - Ok(frame) + Poll::Ready(Ok(frame)) } } @@ -730,7 +759,7 @@ impl WriteHalf { auto_apply_mask: true, vectored: true, writev_threshold: 1024, - write_buffer: Vec::with_capacity(2), + buffer: BytesMut::with_capacity(1024), } } @@ -738,11 +767,22 @@ impl WriteHalf { pub async fn write_frame<'a, S>( &'a mut self, stream: &mut S, - mut frame: Frame<'a>, + frame: Frame<'a>, ) -> Result<(), WebSocketError> where S: AsyncWrite + Unpin, { + self.start_send_frame(frame)?; + poll_fn(|cx| self.poll_flush(stream, cx)).await + } + + /// Writes a frame to the provided stream. + pub fn start_send_frame<'a>( + &'a mut self, + mut frame: Frame<'a>, + ) -> Result<(), WebSocketError> { + // TODO: backpressure check? + if self.role == Role::Client && self.auto_apply_mask { frame.mask(); } @@ -753,15 +793,40 @@ impl WriteHalf { return Err(WebSocketError::ConnectionClosed); } - if self.vectored && frame.payload.len() > self.writev_threshold { - frame.writev(stream).await?; - } else { - let text = frame.write(&mut self.write_buffer); - stream.write_all(text).await?; - } + frame.fmt_head(&mut self.buffer); + self.buffer.extend_from_slice(&frame.payload); Ok(()) } + + pub fn poll_flush<'a, S>( + &'a mut self, + stream: &mut S, + cx: &mut Context<'_>, + ) -> Poll> + where + S: AsyncWrite + Unpin, + { + loop { + // try to flush before writing + if let Err(err) = + ready!(pin!(&mut *stream).poll_flush(cx)).map_err(Into::into) + { + break Poll::Ready(Err(err)); + } + + if self.buffer.is_empty() { + break Poll::Ready(Ok(())); + } + + let written = ready!(pin!(&mut *stream).poll_write(cx, &self.buffer))?; + if written == 0 { + break Poll::Ready(Err(WebSocketError::ConnectionClosed)); + } + + self.buffer.advance(written); + } + } } #[cfg(test)] From 7fc161776b1fd5da8cd6f08c9927959bbcc62bb9 Mon Sep 17 00:00:00 2001 From: Dario Date: Tue, 28 May 2024 16:31:41 +0200 Subject: [PATCH 03/23] poll methods for WebSocketStream --- src/lib.rs | 93 ++++++++++++++++++++++++++++++++++++++++++++++++++++-- 1 file changed, 91 insertions(+), 2 deletions(-) diff --git a/src/lib.rs b/src/lib.rs index b48741a..f264caa 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -167,9 +167,11 @@ pub mod upgrade; use bytes::Buf; use bytes::BytesMut; +use futures::task::AtomicWaker; use std::future::poll_fn; use std::future::Future; use std::pin::pin; +use std::sync::Arc; use std::task::ready; use std::task::Context; use std::task::Poll; @@ -186,6 +188,53 @@ pub use crate::frame::OpCode; pub use crate::frame::Payload; pub use crate::mask::unmask; +enum ContextKind { + /// Read is used when the cx is called from WebSocketRead. + Read, + /// Write is used when the cx is called from WebSocketWrite. + Write, +} + +// WakerDemux keeps 2 wakers, one for WebSocketRead and another for WebSocketWrite. +// +// Waking up the WakerDemux will wake both wakers. +#[derive(Default)] +struct WakerDemux { + read_waker: AtomicWaker, + write_waker: AtomicWaker, +} + +impl futures::task::ArcWake for WakerDemux { + fn wake_by_ref(this: &Arc) { + this.read_waker.wake(); + this.write_waker.wake(); + } +} + +impl WakerDemux { + /// Set the Waker to the corresponding slot. + #[inline] + fn set_waker(&self, kind: ContextKind, waker: &futures::task::Waker) { + match kind { + ContextKind::Read => { + self.read_waker.register(waker); + } + ContextKind::Write => { + self.write_waker.register(waker); + } + } + } + + fn with_context(self: &Arc, f: F) -> R + where + F: FnOnce(&mut Context<'_>) -> R, + { + let waker = futures::task::waker_ref(&self); + let mut cx = Context::from_waker(&waker); + f(&mut cx) + } +} + #[derive(Copy, Clone, PartialEq)] pub enum Role { Server, @@ -345,6 +394,7 @@ pub struct WebSocket { stream: S, write_half: WriteHalf, read_half: ReadHalf, + waker: Arc, } impl<'f, S> WebSocket { @@ -371,8 +421,10 @@ impl<'f, S> WebSocket { where S: AsyncRead + AsyncWrite + Unpin, { + let waker = Arc::new(WakerDemux::default()); Self { stream, + waker, write_half: WriteHalf::after_handshake(role), read_half: ReadHalf::after_handshake(role), } @@ -490,6 +542,31 @@ impl<'f, S> WebSocket { Ok(()) } + pub fn poll_write_frame( + &mut self, + cx: &mut Context<'_>, + frame: Frame<'f>, + ) -> Poll> + where + S: AsyncWrite + Unpin, + { + self.write_half.start_send_frame(frame)?; + self.poll_flush(cx) + } + + pub fn poll_flush( + &mut self, + cx: &mut Context<'_>, + ) -> Poll> + where + S: AsyncWrite + Unpin, + { + self.waker.set_waker(ContextKind::Write, cx.waker()); + self + .waker + .with_context(|cx| self.write_half.poll_flush(&mut self.stream, cx)) + } + /// Reads a frame from the stream. /// /// This method will unmask the frame payload. For fragmented frames, use `FragmentCollector::read_frame`. @@ -523,6 +600,7 @@ impl<'f, S> WebSocket { poll_fn(|cx| self.poll_read_frame(cx)).await } + /// Polls the next frame from the Stream. pub fn poll_read_frame( &mut self, cx: &mut Context<'_>, @@ -538,7 +616,10 @@ impl<'f, S> WebSocket { if let Some(frame) = obligated_send { if !is_closed { self.write_half.start_send_frame(frame)?; - ready!(self.write_half.poll_flush(&mut self.stream, cx))?; + let res = self.waker.with_context(|cx| { + self.write_half.poll_flush(&mut self.stream, cx) + }); + ready!(res)?; } } @@ -584,6 +665,7 @@ impl ReadHalf { poll_fn(|cx| self.poll_read_frame_inner(stream, cx)).await } + /// Reads a frame from the Stream. pub(crate) fn poll_read_frame_inner<'f, S>( &mut self, stream: &mut S, @@ -649,6 +731,7 @@ impl ReadHalf { } } + /// Reads a frame from the Stream parsing the headers. fn poll_parse_frame_header<'a, S>( &mut self, stream: &mut S, @@ -745,6 +828,7 @@ impl ReadHalf { // if we read too much it will stay in the buffer, for the next call to this method let mut message = self.buffer.split_to(payload_len + header_size); + // split the message off of header_size to get the payload. let payload = message.split_off(header_size); let frame = Frame::new(fin, opcode, mask, Payload::Bytes(payload)); Poll::Ready(Ok(frame)) @@ -776,7 +860,9 @@ impl WriteHalf { poll_fn(|cx| self.poll_flush(stream, cx)).await } - /// Writes a frame to the provided stream. + /// Serializes a frame into the internal buffer. + /// + /// This method is similar to [Sink::start_send](https://docs.rs/futures/latest/futures/sink/trait.Sink.html#tymethod.start_send). pub fn start_send_frame<'a>( &'a mut self, mut frame: Frame<'a>, @@ -799,6 +885,9 @@ impl WriteHalf { Ok(()) } + /// Flushes the internal buffer into the Stream. + /// + /// Returns Poll::Ready(Ok(())) when no more bytes are left. pub fn poll_flush<'a, S>( &'a mut self, stream: &mut S, From f1696f1027c1e718cb752f256a9bea3bdfcd5a08 Mon Sep 17 00:00:00 2001 From: Dario Date: Wed, 29 May 2024 11:15:27 +0200 Subject: [PATCH 04/23] WebSocket: start_send_frame function --- src/lib.rs | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/src/lib.rs b/src/lib.rs index f264caa..d9d9e5d 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -542,6 +542,16 @@ impl<'f, S> WebSocket { Ok(()) } + pub fn start_send_frame( + &mut self, + frame: Frame<'f>, + ) -> Result<(), WebSocketError> + where + S: AsyncWrite + Unpin, + { + self.write_half.start_send_frame(frame) + } + pub fn poll_write_frame( &mut self, cx: &mut Context<'_>, From 30c158638f91b0bff9c80e8a62f3a4eb8904419e Mon Sep 17 00:00:00 2001 From: Dario Date: Wed, 29 May 2024 12:05:13 +0200 Subject: [PATCH 05/23] Cargo.toml: futures version --- Cargo.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Cargo.toml b/Cargo.toml index 2955dc6..51ba3d8 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -50,7 +50,7 @@ axum-core = { version = "0.4.3", optional = true } http = { version = "1", optional = true } async-trait = { version = "0.1", optional = true } tokio-util = { version = "0.7.11", features = ["codec", "io"] } -futures = { version = "0.3.30", default-features = false, features = ["std"] } +futures = { version = "0.3", default-features = false, features = ["std"] } [features] default = ["simd"] From 909cd67e6cd81295f7702c1a4c44e417d5898b1d Mon Sep 17 00:00:00 2001 From: Dario Date: Wed, 29 May 2024 12:07:15 +0200 Subject: [PATCH 06/23] Cargo.toml: tokio_util version --- Cargo.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Cargo.toml b/Cargo.toml index 51ba3d8..d73e76e 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -49,7 +49,7 @@ bytes = "1.5.0" axum-core = { version = "0.4.3", optional = true } http = { version = "1", optional = true } async-trait = { version = "0.1", optional = true } -tokio-util = { version = "0.7.11", features = ["codec", "io"] } +tokio-util = { version = "0.7", features = ["codec", "io"] } futures = { version = "0.3", default-features = false, features = ["std"] } [features] From 78f81ddf0abd249cceadaac2872cf34f5187bc73 Mon Sep 17 00:00:00 2001 From: Dario Date: Wed, 29 May 2024 17:13:32 +0200 Subject: [PATCH 07/23] Set the waker before flushing --- src/lib.rs | 1 + 1 file changed, 1 insertion(+) diff --git a/src/lib.rs b/src/lib.rs index d9d9e5d..8780cbd 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -626,6 +626,7 @@ impl<'f, S> WebSocket { if let Some(frame) = obligated_send { if !is_closed { self.write_half.start_send_frame(frame)?; + self.waker.set_waker(ContextKind::Read, cx.waker()); let res = self.waker.with_context(|cx| { self.write_half.poll_flush(&mut self.stream, cx) }); From 23a7101a72d57000e13c9798ca5e5c30bcef6e1c Mon Sep 17 00:00:00 2001 From: Dario Date: Wed, 29 May 2024 17:54:49 +0200 Subject: [PATCH 08/23] Implement poll_read_frame & poll_write_frame for WebSocketRead and WebSocketWrite --- src/lib.rs | 23 +++++++++++++++++++++++ 1 file changed, 23 insertions(+) diff --git a/src/lib.rs b/src/lib.rs index 8780cbd..08f07fa 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -353,6 +353,17 @@ impl<'f, S> WebSocketRead { } } } + + /// Reads a frame from the stream. + pub fn poll_read_frame( + &mut self, + cx: &mut Context<'_>, + ) -> Poll<(Result, WebSocketError>, Option)> + where + S: AsyncRead + Unpin, + { + self.read_half.poll_read_frame_inner(&mut self.stream, cx) + } } impl<'f, S> WebSocketWrite { @@ -387,6 +398,18 @@ impl<'f, S> WebSocketWrite { { self.write_half.write_frame(&mut self.stream, frame).await } + + pub fn poll_write_frame( + &mut self, + cx: &mut Context<'_>, + frame: Frame<'f>, + ) -> Poll> + where + S: AsyncWrite + Unpin, + { + self.write_half.start_send_frame(frame)?; + self.write_half.poll_flush(&mut self.stream, cx) + } } /// WebSocket protocol implementation over an async stream. From 59904ada70d81716a8d9fbc539ee6cc778efa7d9 Mon Sep 17 00:00:00 2001 From: Dario Date: Sun, 23 Jun 2024 16:45:45 +0200 Subject: [PATCH 09/23] Bring back unstable-split feature --- Cargo.toml | 1 + src/error.rs | 1 + src/fragment.rs | 4 ++++ src/lib.rs | 10 +++++++++- 4 files changed, 15 insertions(+), 1 deletion(-) diff --git a/Cargo.toml b/Cargo.toml index d73e76e..e32eaf3 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -63,6 +63,7 @@ upgrade = [ "hyper-util", "http-body-util", ] +unstable-split = [] # Axum integration with_axum = ["axum-core", "http", "async-trait"] diff --git a/src/error.rs b/src/error.rs index b8e08d5..848116a 100644 --- a/src/error.rs +++ b/src/error.rs @@ -41,6 +41,7 @@ pub enum WebSocketError { #[cfg(feature = "upgrade")] #[error(transparent)] HTTPError(#[from] hyper::Error), + #[cfg(feature = "unstable-split")] #[error("Failed to send frame")] SendError(#[from] Box), } diff --git a/src/fragment.rs b/src/fragment.rs index 9a1932a..091fb03 100644 --- a/src/fragment.rs +++ b/src/fragment.rs @@ -12,6 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. +#[cfg(feature = "unstable-split")] use std::future::Future; use crate::error::WebSocketError; @@ -19,6 +20,7 @@ use crate::frame::Frame; use crate::OpCode; use crate::ReadHalf; use crate::WebSocket; +#[cfg(feature = "unstable-split")] use crate::WebSocketRead; use crate::WriteHalf; use tokio::io::AsyncRead; @@ -135,12 +137,14 @@ impl<'f, S> FragmentCollector { } } +#[cfg(feature = "unstable-split")] pub struct FragmentCollectorRead { stream: S, read_half: ReadHalf, fragments: Fragments, } +#[cfg(feature = "unstable-split")] impl<'f, S> FragmentCollectorRead { /// Creates a new `FragmentCollector` with the provided `WebSocket`. pub fn new(ws: WebSocketRead) -> FragmentCollectorRead diff --git a/src/lib.rs b/src/lib.rs index 08f07fa..d085e77 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -169,19 +169,22 @@ use bytes::Buf; use bytes::BytesMut; use futures::task::AtomicWaker; use std::future::poll_fn; -use std::future::Future; use std::pin::pin; use std::sync::Arc; use std::task::ready; use std::task::Context; use std::task::Poll; +#[cfg(feature = "unstable-split")] +use std::future::Future; + use tokio::io::AsyncRead; use tokio::io::AsyncWrite; pub use crate::close::CloseCode; pub use crate::error::WebSocketError; pub use crate::fragment::FragmentCollector; +#[cfg(feature = "unstable-split")] pub use crate::fragment::FragmentCollectorRead; pub use crate::frame::Frame; pub use crate::frame::OpCode; @@ -260,6 +263,7 @@ pub(crate) struct ReadHalf { buffer: BytesMut, } +#[cfg(feature = "unstable-split")] pub struct WebSocketRead { stream: S, read_half: ReadHalf, @@ -270,6 +274,7 @@ pub struct WebSocketWrite { write_half: WriteHalf, } +#[cfg(feature = "unstable-split")] /// Create a split `WebSocketRead`/`WebSocketWrite` pair from a stream that has already completed the WebSocket handshake. pub fn after_handshake_split( read: R, @@ -292,6 +297,7 @@ where ) } +#[cfg(feature = "unstable-split")] impl<'f, S> WebSocketRead { /// Consumes the `WebSocketRead` and returns the underlying stream. #[inline] @@ -366,6 +372,7 @@ impl<'f, S> WebSocketRead { } } +#[cfg(feature = "unstable-split")] impl<'f, S> WebSocketWrite { /// Sets whether to use vectored writes. This option does not guarantee that vectored writes will be always used. /// @@ -453,6 +460,7 @@ impl<'f, S> WebSocket { } } + #[cfg(feature = "unstable-split")] /// Split a [`WebSocket`] into a [`WebSocketRead`] and [`WebSocketWrite`] half. Note that the split version does not /// handle fragmented packets and you may wish to create a [`FragmentCollectorRead`] over top of the read half that /// is returned. From 8285b556f25338c98334dc22853709988d905894 Mon Sep 17 00:00:00 2001 From: Dario Date: Mon, 24 Jun 2024 08:45:13 +0200 Subject: [PATCH 10/23] write_frame: Check readiness of the underlying connection by flushing some buffered data --- src/lib.rs | 51 +++++++++++++++++++++++++++++++++++++++++++++++++-- 1 file changed, 49 insertions(+), 2 deletions(-) diff --git a/src/lib.rs b/src/lib.rs index d085e77..df8f4ff 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -898,12 +898,57 @@ impl WriteHalf { where S: AsyncWrite + Unpin, { - self.start_send_frame(frame)?; - poll_fn(|cx| self.poll_flush(stream, cx)).await + // maybe_frame determines the state. + // If a frame is present we need to poll_ready, else flush it. + let mut maybe_frame = Some(frame); + poll_fn(|cx| loop { + match maybe_frame.take() { + Some(frame) => match self.poll_ready(stream, cx) { + Poll::Ready(res) => { + res?; + self.start_send_frame(frame)?; + } + Poll::Pending => { + maybe_frame = Some(frame); + return Poll::Pending; + } + }, + None => { + return self.poll_flush(stream, cx); + } + } + }) + .await + } + + /// Ensures that the underlying connection is ready. It will try to flush the contents if any. + /// + /// If you prefer to buffer requests as much as possible you can skip this step, generally and + /// call start_send_frame. + pub fn poll_ready( + &mut self, + stream: &mut S, + cx: &mut Context<'_>, + ) -> Poll> + where + S: AsyncWrite + Unpin, + { + if self.buffer.is_empty() { + Poll::Ready(Ok(())) + } else { + let written = ready!(pin!(&mut *stream).poll_write(cx, &self.buffer))?; + if written == 0 { + Poll::Ready(Err(WebSocketError::ConnectionClosed)) + } else { + Poll::Ready(Ok(())) + } + } } /// Serializes a frame into the internal buffer. /// + /// Beware of the internal buffer. If the other end of the connection is not consuming fast enough it might fill fast. + /// /// This method is similar to [Sink::start_send](https://docs.rs/futures/latest/futures/sink/trait.Sink.html#tymethod.start_send). pub fn start_send_frame<'a>( &'a mut self, @@ -921,6 +966,8 @@ impl WriteHalf { return Err(WebSocketError::ConnectionClosed); } + // TODO(dgrr): Cap max payload size with a user setting? + frame.fmt_head(&mut self.buffer); self.buffer.extend_from_slice(&frame.payload); From 3f050bf64fb91cdbfd015fa071e24ca519dc7846 Mon Sep 17 00:00:00 2001 From: Dario Date: Mon, 24 Jun 2024 16:51:51 +0200 Subject: [PATCH 11/23] Function docs --- src/lib.rs | 41 +++++++++++++++++++++++++++++++---------- 1 file changed, 31 insertions(+), 10 deletions(-) diff --git a/src/lib.rs b/src/lib.rs index df8f4ff..d67e378 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -198,9 +198,13 @@ enum ContextKind { Write, } -// WakerDemux keeps 2 wakers, one for WebSocketRead and another for WebSocketWrite. +// WakerDemux keeps track of whether the waker was called from a reader or a writer. // -// Waking up the WakerDemux will wake both wakers. +// This is important because the reader can also write, in order to reply to Ping or Close messages. +// If we didn't implement the WakerDemux the reader could hijack the writer's Waker and the writer's task +// would never get notified. +// +// Waking up the WakerDemux will wake the read and write tasks. #[derive(Default)] struct WakerDemux { read_waker: AtomicWaker, @@ -228,6 +232,7 @@ impl WakerDemux { } } + #[inline] fn with_context(self: &Arc, f: F) -> R where F: FnOnce(&mut Context<'_>) -> R, @@ -238,6 +243,10 @@ impl WakerDemux { } } +/// The role the connection is taking. +/// +/// When a server role is taken the frames will not be masked, unlike +/// the client role, in which frames are masked. #[derive(Copy, Clone, PartialEq)] pub enum Role { Server, @@ -264,11 +273,14 @@ pub(crate) struct ReadHalf { } #[cfg(feature = "unstable-split")] +/// Read end of a WebSocket connection. pub struct WebSocketRead { stream: S, read_half: ReadHalf, } +#[cfg(feature = "unstable-split")] +/// Write end of a WebSocket connection. pub struct WebSocketWrite { stream: S, write_half: WriteHalf, @@ -392,10 +404,12 @@ impl<'f, S> WebSocketWrite { self.write_half.auto_apply_mask = auto_apply_mask; } + /// Returns whether the connection was closed or not. pub fn is_closed(&self) -> bool { self.write_half.closed } + /// Sends a frame. pub async fn write_frame( &mut self, frame: Frame<'f>, @@ -406,6 +420,9 @@ impl<'f, S> WebSocketWrite { self.write_half.write_frame(&mut self.stream, frame).await } + /// Serializes the frame into the internal buffer and tries to flush the contents. + /// + /// If the function returns Poll::Pending, the user needs to call poll_flush. pub fn poll_write_frame( &mut self, cx: &mut Context<'_>, @@ -541,6 +558,7 @@ impl<'f, S> WebSocket { self.write_half.auto_apply_mask = auto_apply_mask; } + /// Returns whether the connection is closed or not. pub fn is_closed(&self) -> bool { self.write_half.closed } @@ -573,6 +591,9 @@ impl<'f, S> WebSocket { Ok(()) } + /// Serializes a frame into the internal buffer. + /// + /// This method is similar to [Sink::start_send](https://docs.rs/futures/0.3.30/futures/sink/trait.Sink.html#tymethod.start_send). pub fn start_send_frame( &mut self, frame: Frame<'f>, @@ -583,6 +604,11 @@ impl<'f, S> WebSocket { self.write_half.start_send_frame(frame) } + /// Serializes a frame into the internal buffer. + /// + /// Beware of the internal buffer. If the other end of the connection is not consuming fast enough it might fill fast. + /// + /// This method is similar to [Sink::start_send](https://docs.rs/futures/0.3.30/futures/sink/trait.Sink.html#tymethod.start_send). pub fn poll_write_frame( &mut self, cx: &mut Context<'_>, @@ -595,6 +621,9 @@ impl<'f, S> WebSocket { self.poll_flush(cx) } + /// Flushes the internal buffer into the Stream. + /// + /// Returns Poll::Ready(Ok(())) when no more bytes are left. pub fn poll_flush( &mut self, cx: &mut Context<'_>, @@ -945,11 +974,6 @@ impl WriteHalf { } } - /// Serializes a frame into the internal buffer. - /// - /// Beware of the internal buffer. If the other end of the connection is not consuming fast enough it might fill fast. - /// - /// This method is similar to [Sink::start_send](https://docs.rs/futures/latest/futures/sink/trait.Sink.html#tymethod.start_send). pub fn start_send_frame<'a>( &'a mut self, mut frame: Frame<'a>, @@ -974,9 +998,6 @@ impl WriteHalf { Ok(()) } - /// Flushes the internal buffer into the Stream. - /// - /// Returns Poll::Ready(Ok(())) when no more bytes are left. pub fn poll_flush<'a, S>( &'a mut self, stream: &mut S, From b7b769641d144267d4588019850e2991f428d11f Mon Sep 17 00:00:00 2001 From: Dario Date: Mon, 24 Jun 2024 16:52:52 +0200 Subject: [PATCH 12/23] Fix return type when not using SIMD for utf8 processing --- src/lib.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/lib.rs b/src/lib.rs index d67e378..a794ddc 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -773,7 +773,7 @@ impl ReadHalf { #[cfg(not(feature = "simd"))] if std::str::from_utf8(&frame.payload[2..]).is_err() { - return (Err(WebSocketError::InvalidUTF8), None); + return Poll::Ready((Err(WebSocketError::InvalidUTF8), None)); }; if !code.is_allowed() { From b1b087cf1e0c669b336d70ba1e5e4bc1ebb32609 Mon Sep 17 00:00:00 2001 From: Dario Date: Sun, 30 Jun 2024 16:22:34 +0200 Subject: [PATCH 13/23] Generate state machine for reading --- src/lib.rs | 221 +++++++++++++++++++++++++++++++++++------------------ 1 file changed, 147 insertions(+), 74 deletions(-) diff --git a/src/lib.rs b/src/lib.rs index a794ddc..2768da9 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -269,9 +269,30 @@ pub(crate) struct ReadHalf { auto_pong: bool, writev_threshold: usize, max_message_size: usize, + read_state: Option, buffer: BytesMut, } +struct Header { + fin: bool, + masked: bool, + opcode: OpCode, + extra: usize, + length_code: u8, + header_size: usize, +} + +struct HeaderAndMask { + header: Header, + mask: Option<[u8; 4]>, + payload_len: usize, +} + +enum ReadState { + Header(Header), + Payload(HeaderAndMask), +} + #[cfg(feature = "unstable-split")] /// Read end of a WebSocket connection. pub struct WebSocketRead { @@ -373,6 +394,7 @@ impl<'f, S> WebSocketRead { } /// Reads a frame from the stream. + #[inline(always)] pub fn poll_read_frame( &mut self, cx: &mut Context<'_>, @@ -609,6 +631,7 @@ impl<'f, S> WebSocket { /// Beware of the internal buffer. If the other end of the connection is not consuming fast enough it might fill fast. /// /// This method is similar to [Sink::start_send](https://docs.rs/futures/0.3.30/futures/sink/trait.Sink.html#tymethod.start_send). + #[inline(always)] pub fn poll_write_frame( &mut self, cx: &mut Context<'_>, @@ -624,6 +647,7 @@ impl<'f, S> WebSocket { /// Flushes the internal buffer into the Stream. /// /// Returns Poll::Ready(Ok(())) when no more bytes are left. + #[inline(always)] pub fn poll_flush( &mut self, cx: &mut Context<'_>, @@ -663,6 +687,7 @@ impl<'f, S> WebSocket { /// Ok(()) /// } /// ``` + #[inline(always)] pub async fn read_frame(&mut self) -> Result, WebSocketError> where S: AsyncRead + AsyncWrite + Unpin, @@ -711,6 +736,7 @@ impl ReadHalf { Self { role, + read_state: None, auto_apply_mask: true, auto_close: true, auto_pong: true, @@ -726,6 +752,7 @@ impl ReadHalf { /// has been closed. /// /// XXX: Do not expose this method to the public API. + #[inline(always)] pub(crate) async fn read_frame_inner<'f, S>( &mut self, stream: &mut S, @@ -812,97 +839,143 @@ impl ReadHalf { S: AsyncRead + Unpin, { macro_rules! read_next { - () => {{ - let bytes_read = ready!(tokio_util::io::poll_read_buf( + ($variant:expr,$value:expr) => {{ + let bytes_read = match tokio_util::io::poll_read_buf( pin!(&mut *stream), cx, - &mut self.buffer - ))?; + &mut self.buffer, + ) { + Poll::Ready(ready) => ready, + Poll::Pending => { + self.read_state = Some($variant($value)); + return Poll::Pending; + } + }?; if bytes_read == 0 { return Poll::Ready(Err(WebSocketError::UnexpectedEOF)); } }}; } - // Read the first two bytes - while self.buffer.remaining() < 2 { - read_next!(); - } - - let fin = self.buffer[0] & 0b10000000 != 0; - let rsv1 = self.buffer[0] & 0b01000000 != 0; - let rsv2 = self.buffer[0] & 0b00100000 != 0; - let rsv3 = self.buffer[0] & 0b00010000 != 0; - - if rsv1 || rsv2 || rsv3 { - return Poll::Ready(Err(WebSocketError::ReservedBitsNotZero)); - } - - let opcode = frame::OpCode::try_from(self.buffer[0] & 0b00001111)?; - let masked = self.buffer[1] & 0b10000000 != 0; + loop { + match self.read_state.take() { + None => { + // Read the first two bytes + while self.buffer.remaining() < 2 { + let bytes_read = ready!(tokio_util::io::poll_read_buf( + pin!(&mut *stream), + cx, + &mut self.buffer + ))?; + if bytes_read == 0 { + return Poll::Ready(Err(WebSocketError::UnexpectedEOF)); + } + } - let length_code = self.buffer[1] & 0x7F; - let extra = match length_code { - 126 => 2, - 127 => 8, - _ => 0, - }; + let fin = self.buffer[0] & 0b10000000 != 0; + let rsv1 = self.buffer[0] & 0b01000000 != 0; + let rsv2 = self.buffer[0] & 0b00100000 != 0; + let rsv3 = self.buffer[0] & 0b00010000 != 0; - // total header size - let header_size = 2 + extra + masked as usize * 4; - while self.buffer.remaining() < header_size { - read_next!(); - } + if rsv1 || rsv2 || rsv3 { + return Poll::Ready(Err(WebSocketError::ReservedBitsNotZero)); + } - let mut header = &self.buffer[2..header_size]; - - let payload_len: usize = match extra { - 0 => usize::from(length_code), - 2 => header.get_u16() as usize, - #[cfg(any(target_pointer_width = "64", target_pointer_width = "128"))] - 8 => header.get_u64() as usize, - // On 32bit systems, usize is only 4bytes wide so we must check for usize overflowing - #[cfg(any( - target_pointer_width = "8", - target_pointer_width = "16", - target_pointer_width = "32" - ))] - 8 => match usize::try_from(header.get_u64()) { - Ok(length) => length, - Err(_) => return Err(WebSocketError::FrameTooLarge), - }, - _ => unreachable!(), - }; + let opcode = frame::OpCode::try_from(self.buffer[0] & 0b00001111)?; + let masked = self.buffer[1] & 0b10000000 != 0; + + let length_code = self.buffer[1] & 0x7F; + let extra = match length_code { + 126 => 2, + 127 => 8, + _ => 0, + }; + + let header_size = extra + masked as usize * 4; + self.buffer.advance(2); + + self.read_state = Some(ReadState::Header(Header { + fin, + masked, + opcode, + length_code, + extra, + header_size, + })); + } + Some(ReadState::Header(header)) => { + // total header size + while self.buffer.remaining() < header.header_size { + read_next!(ReadState::Header, header); + } - let mask = if masked { - Some(header.get_u32().to_be_bytes()) - } else { - None - }; + let payload_len: usize = match header.extra { + 0 => usize::from(header.length_code), + 2 => self.buffer.get_u16() as usize, + #[cfg(any( + target_pointer_width = "64", + target_pointer_width = "128" + ))] + 8 => self.buffer.get_u64() as usize, + // On 32bit systems, usize is only 4bytes wide so we must check for usize overflowing + #[cfg(any( + target_pointer_width = "8", + target_pointer_width = "16", + target_pointer_width = "32" + ))] + 8 => match usize::try_from(self.buffer.get_u64()) { + Ok(length) => length, + Err(_) => return Err(WebSocketError::FrameTooLarge), + }, + _ => unreachable!(), + }; + + let mask = if header.masked { + Some(self.buffer.get_u32().to_be_bytes()) + } else { + None + }; + + if frame::is_control(header.opcode) && !header.fin { + return Poll::Ready(Err(WebSocketError::ControlFrameFragmented)); + } - if frame::is_control(opcode) && !fin { - return Poll::Ready(Err(WebSocketError::ControlFrameFragmented)); - } + if header.opcode == OpCode::Ping && payload_len > 125 { + return Poll::Ready(Err(WebSocketError::PingFrameTooLarge)); + } - if opcode == OpCode::Ping && payload_len > 125 { - return Poll::Ready(Err(WebSocketError::PingFrameTooLarge)); - } + if payload_len >= self.max_message_size { + return Poll::Ready(Err(WebSocketError::FrameTooLarge)); + } - if payload_len >= self.max_message_size { - return Poll::Ready(Err(WebSocketError::FrameTooLarge)); - } + self.read_state = Some(ReadState::Payload(HeaderAndMask { + header, + mask, + payload_len, + })); + } + Some(ReadState::Payload(header_and_mask)) => { + // Reserve a bit more to try to get next frame header and avoid a syscall to read it next time + self.buffer.reserve(header_and_mask.payload_len + 14); + while self.buffer.remaining() < header_and_mask.payload_len { + read_next!(ReadState::Payload, header_and_mask); + } - // Reserve a bit more to try to get next frame header and avoid a syscall to read it next time - while header_size + payload_len > self.buffer.remaining() { - read_next!(); + let header = header_and_mask.header; + let mask = header_and_mask.mask; + let payload_len = header_and_mask.payload_len; + + let payload = self.buffer.split_to(payload_len); + let frame = Frame::new( + header.fin, + header.opcode, + mask, + Payload::Bytes(payload), + ); + break Poll::Ready(Ok(frame)); + } + } } - - // if we read too much it will stay in the buffer, for the next call to this method - let mut message = self.buffer.split_to(payload_len + header_size); - // split the message off of header_size to get the payload. - let payload = message.split_off(header_size); - let frame = Frame::new(fin, opcode, mask, Payload::Bytes(payload)); - Poll::Ready(Ok(frame)) } } From 3636e5b035de6016d42646240f69868ff69d6083 Mon Sep 17 00:00:00 2001 From: Dario Date: Sun, 30 Jun 2024 21:49:21 +0200 Subject: [PATCH 14/23] WriteHalf: Advance buffer on poll_ready --- src/lib.rs | 1 + 1 file changed, 1 insertion(+) diff --git a/src/lib.rs b/src/lib.rs index 2768da9..08e11dc 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1042,6 +1042,7 @@ impl WriteHalf { if written == 0 { Poll::Ready(Err(WebSocketError::ConnectionClosed)) } else { + self.buffer.advance(written); Poll::Ready(Ok(())) } } From 3a537e3923e9162605f749d2739503cf78f58a7f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Dar=C3=ADo?= Date: Sun, 30 Jun 2024 22:03:08 +0200 Subject: [PATCH 15/23] Simplify loop by conradludgate Co-authored-by: Conrad Ludgate --- src/lib.rs | 20 ++++++-------------- 1 file changed, 6 insertions(+), 14 deletions(-) diff --git a/src/lib.rs b/src/lib.rs index 08e11dc..d699a70 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1080,25 +1080,17 @@ impl WriteHalf { where S: AsyncWrite + Unpin, { - loop { - // try to flush before writing - if let Err(err) = - ready!(pin!(&mut *stream).poll_flush(cx)).map_err(Into::into) - { - break Poll::Ready(Err(err)); - } - - if self.buffer.is_empty() { - break Poll::Ready(Ok(())); - } - + // flush the buffer + while !self.buffer.is_empty() { let written = ready!(pin!(&mut *stream).poll_write(cx, &self.buffer))?; if written == 0 { - break Poll::Ready(Err(WebSocketError::ConnectionClosed)); + return Poll::Ready(Err(WebSocketError::ConnectionClosed)); } - self.buffer.advance(written); } + + // flush the stream + Poll::Ready(ready!(pin!(&mut *stream).poll_flush(cx)).map_err(Into::into)) } } From da4f6d71c9a143a168d1e7799ea9544732524465 Mon Sep 17 00:00:00 2001 From: Conrad Ludgate Date: Mon, 1 Jul 2024 08:30:34 +0100 Subject: [PATCH 16/23] use vec --- src/lib.rs | 25 ++++++++++++++++--------- 1 file changed, 16 insertions(+), 9 deletions(-) diff --git a/src/lib.rs b/src/lib.rs index d699a70..920de31 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -259,7 +259,8 @@ pub(crate) struct WriteHalf { vectored: bool, auto_apply_mask: bool, writev_threshold: usize, - buffer: BytesMut, + buf_pos: usize, + buffer: Vec, } pub(crate) struct ReadHalf { @@ -987,7 +988,8 @@ impl WriteHalf { auto_apply_mask: true, vectored: true, writev_threshold: 1024, - buffer: BytesMut::with_capacity(1024), + buf_pos: 0, + buffer: Vec::with_capacity(1024), } } @@ -1035,14 +1037,16 @@ impl WriteHalf { where S: AsyncWrite + Unpin, { - if self.buffer.is_empty() { + if self.buf_pos >= self.buffer.len() { Poll::Ready(Ok(())) } else { - let written = ready!(pin!(&mut *stream).poll_write(cx, &self.buffer))?; + let written = ready!( + pin!(&mut *stream).poll_write(cx, &self.buffer[self.buf_pos..]) + )?; + self.buf_pos += written; if written == 0 { Poll::Ready(Err(WebSocketError::ConnectionClosed)) } else { - self.buffer.advance(written); Poll::Ready(Ok(())) } } @@ -1066,6 +1070,8 @@ impl WriteHalf { // TODO(dgrr): Cap max payload size with a user setting? + self.buffer.splice(0..self.buf_pos, [0u8; 0]); + self.buf_pos = 0; frame.fmt_head(&mut self.buffer); self.buffer.extend_from_slice(&frame.payload); @@ -1080,13 +1086,14 @@ impl WriteHalf { where S: AsyncWrite + Unpin, { - // flush the buffer - while !self.buffer.is_empty() { - let written = ready!(pin!(&mut *stream).poll_write(cx, &self.buffer))?; + while self.buf_pos < self.buffer.len() { + let written = ready!( + pin!(&mut *stream).poll_write(cx, &self.buffer[self.buf_pos..]) + )?; if written == 0 { return Poll::Ready(Err(WebSocketError::ConnectionClosed)); } - self.buffer.advance(written); + self.buf_pos += written; } // flush the stream From 9abbf6c00aea0ae8788784a1e872f53bda0da48f Mon Sep 17 00:00:00 2001 From: Conrad Ludgate Date: Mon, 1 Jul 2024 09:13:48 +0100 Subject: [PATCH 17/23] better buffer manipulation --- src/frame.rs | 25 ++++++++++++------------- src/lib.rs | 35 ++++++++++++++++++++++++----------- 2 files changed, 36 insertions(+), 24 deletions(-) diff --git a/src/frame.rs b/src/frame.rs index 0a34ce6..ba2676d 100644 --- a/src/frame.rs +++ b/src/frame.rs @@ -14,7 +14,7 @@ use tokio::io::AsyncWriteExt; -use bytes::{BufMut, BytesMut}; +use bytes::BytesMut; use core::ops::Deref; use crate::WebSocketError; @@ -138,7 +138,7 @@ pub struct Frame<'f> { pub payload: Payload<'f>, } -const MAX_HEAD_SIZE: usize = 16; +pub(crate) const MAX_HEAD_SIZE: usize = 16; impl<'f> Frame<'f> { /// Creates a new WebSocket `Frame`. @@ -259,27 +259,26 @@ impl<'f> Frame<'f> { /// # Panics /// /// This method panics if the head buffer is not at least n-bytes long, where n is the size of the length field (0, 2, 4, or 10) - pub fn fmt_head(&mut self, mut head: impl BufMut) -> usize { - head.put_u8((self.fin as u8) << 7 | (self.opcode as u8)); - - let mask_bit = if self.mask.is_some() { 0x80 } else { 0x0 }; + pub fn fmt_head(&mut self, head: &mut [u8]) -> usize { + head[0] = (self.fin as u8) << 7 | (self.opcode as u8); let len = self.payload.len(); let size = if len < 126 { - head.put_u8(len as u8 | mask_bit); + head[1] = len as u8; 2 } else if len < 65536 { - head.put_u8(126u8 | mask_bit); - head.put_slice(&(len as u16).to_be_bytes()); + head[1] = 126; + head[2..4].copy_from_slice(&(len as u16).to_be_bytes()); 4 } else { - head.put_u8(127u8 | mask_bit); - head.put_slice(&(len as u64).to_be_bytes()); + head[1] = 127; + head[2..10].copy_from_slice(&(len as u64).to_be_bytes()); 10 }; if let Some(mask) = self.mask { - head.put_slice(&mask); + head[1] |= 0x80; + head[size..size + 4].copy_from_slice(&mask); size + 4 } else { size @@ -296,7 +295,7 @@ impl<'f> Frame<'f> { use std::io::IoSlice; let mut head = [0; MAX_HEAD_SIZE]; - let size = self.fmt_head(&mut head[..]); + let size = self.fmt_head(&mut head); let total = size + self.payload.len(); diff --git a/src/lib.rs b/src/lib.rs index 920de31..74e494f 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -167,6 +167,7 @@ pub mod upgrade; use bytes::Buf; use bytes::BytesMut; +use frame::MAX_HEAD_SIZE; use futures::task::AtomicWaker; use std::future::poll_fn; use std::pin::pin; @@ -259,7 +260,7 @@ pub(crate) struct WriteHalf { vectored: bool, auto_apply_mask: bool, writev_threshold: usize, - buf_pos: usize, + read_head: usize, buffer: Vec, } @@ -988,7 +989,7 @@ impl WriteHalf { auto_apply_mask: true, vectored: true, writev_threshold: 1024, - buf_pos: 0, + read_head: 0, buffer: Vec::with_capacity(1024), } } @@ -1037,13 +1038,13 @@ impl WriteHalf { where S: AsyncWrite + Unpin, { - if self.buf_pos >= self.buffer.len() { + if self.read_head >= self.buffer.len() { Poll::Ready(Ok(())) } else { let written = ready!( - pin!(&mut *stream).poll_write(cx, &self.buffer[self.buf_pos..]) + pin!(&mut *stream).poll_write(cx, &self.buffer[self.read_head..]) )?; - self.buf_pos += written; + self.read_head += written; if written == 0 { Poll::Ready(Err(WebSocketError::ConnectionClosed)) } else { @@ -1070,9 +1071,21 @@ impl WriteHalf { // TODO(dgrr): Cap max payload size with a user setting? - self.buffer.splice(0..self.buf_pos, [0u8; 0]); - self.buf_pos = 0; - frame.fmt_head(&mut self.buffer); + let payload_len = frame.payload.len(); + let max_len = payload_len + MAX_HEAD_SIZE; + if self.buffer.len() + max_len > self.buffer.capacity() { + // if the len we need for this frame will require a realloc, let's clear the written head of the buffer + self.buffer.splice(0..self.read_head, [0u8; 0]); + self.read_head = 0; + self.buffer.reserve(max_len); + } + // resize the buffer so we have room to write the head + let current_len = self.buffer.len(); + self.buffer.resize(current_len + MAX_HEAD_SIZE, 0); + + let buf = &mut self.buffer[current_len..]; + let size = frame.fmt_head(buf); + self.buffer.truncate(current_len + size); self.buffer.extend_from_slice(&frame.payload); Ok(()) @@ -1086,14 +1099,14 @@ impl WriteHalf { where S: AsyncWrite + Unpin, { - while self.buf_pos < self.buffer.len() { + while self.read_head < self.buffer.len() { let written = ready!( - pin!(&mut *stream).poll_write(cx, &self.buffer[self.buf_pos..]) + pin!(&mut *stream).poll_write(cx, &self.buffer[self.read_head..]) )?; if written == 0 { return Poll::Ready(Err(WebSocketError::ConnectionClosed)); } - self.buf_pos += written; + self.read_head += written; } // flush the stream From e91e9226347021f312071306ca4dbe67d5bfd52d Mon Sep 17 00:00:00 2001 From: Conrad Ludgate Date: Mon, 1 Jul 2024 09:36:21 +0100 Subject: [PATCH 18/23] re-introduce vectored writes --- src/lib.rs | 117 ++++++++++++++++++++++++++++++++++++++++++++--------- 1 file changed, 97 insertions(+), 20 deletions(-) diff --git a/src/lib.rs b/src/lib.rs index 74e494f..a51c08f 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -166,10 +166,13 @@ pub mod upgrade; use bytes::Buf; +use bytes::Bytes; use bytes::BytesMut; use frame::MAX_HEAD_SIZE; use futures::task::AtomicWaker; +use std::collections::VecDeque; use std::future::poll_fn; +use std::io::IoSlice; use std::pin::pin; use std::sync::Arc; use std::task::ready; @@ -260,8 +263,18 @@ pub(crate) struct WriteHalf { vectored: bool, auto_apply_mask: bool, writev_threshold: usize, + write_buffer: Vec, + // where in the write_buffer we should read from when writing to the stream read_head: usize, - buffer: Vec, + // only used with vectored writes. stores the frame payloads + payloads: VecDeque, +} + +struct WriteBuffer { + // where in the write_buffer this payload should be inserted + position: usize, + read_head: usize, + payload: Payload<'static>, } pub(crate) struct ReadHalf { @@ -990,7 +1003,8 @@ impl WriteHalf { vectored: true, writev_threshold: 1024, read_head: 0, - buffer: Vec::with_capacity(1024), + write_buffer: Vec::with_capacity(1024), + payloads: VecDeque::with_capacity(1), } } @@ -1038,13 +1052,36 @@ impl WriteHalf { where S: AsyncWrite + Unpin, { - if self.read_head >= self.buffer.len() { + if self.read_head >= self.write_buffer.len() && self.payloads.is_empty() { Poll::Ready(Ok(())) } else { - let written = ready!( - pin!(&mut *stream).poll_write(cx, &self.buffer[self.read_head..]) - )?; - self.read_head += written; + let written = if let Some(front) = self.payloads.front_mut() { + let b = [ + IoSlice::new(&self.write_buffer[self.read_head..front.position]), + IoSlice::new(&front.payload), + ]; + + let written = ready!(pin!(&mut *stream).poll_write_vectored(cx, &b))?; + + if written < b[0].len() { + self.read_head += written; + } else { + let written = written - b[0].len(); + self.read_head = front.position; + front.read_head += written; + if front.read_head == front.payload.len() { + self.payloads.pop_front(); + } + } + + written + } else { + let written = ready!(pin!(&mut *stream) + .poll_write(cx, &self.write_buffer[self.read_head..]))?; + self.read_head += written; + written + }; + if written == 0 { Poll::Ready(Err(WebSocketError::ConnectionClosed)) } else { @@ -1073,20 +1110,36 @@ impl WriteHalf { let payload_len = frame.payload.len(); let max_len = payload_len + MAX_HEAD_SIZE; - if self.buffer.len() + max_len > self.buffer.capacity() { + if self.write_buffer.len() + max_len > self.write_buffer.capacity() { // if the len we need for this frame will require a realloc, let's clear the written head of the buffer - self.buffer.splice(0..self.read_head, [0u8; 0]); + self.write_buffer.splice(0..self.read_head, [0u8; 0]); self.read_head = 0; - self.buffer.reserve(max_len); + self.write_buffer.reserve(max_len); } // resize the buffer so we have room to write the head - let current_len = self.buffer.len(); - self.buffer.resize(current_len + MAX_HEAD_SIZE, 0); + let current_len = self.write_buffer.len(); + self.write_buffer.resize(current_len + MAX_HEAD_SIZE, 0); - let buf = &mut self.buffer[current_len..]; + let buf = &mut self.write_buffer[current_len..]; let size = frame.fmt_head(buf); - self.buffer.truncate(current_len + size); - self.buffer.extend_from_slice(&frame.payload); + self.write_buffer.truncate(current_len + size); + + let vectored = self.vectored && frame.payload.len() > self.writev_threshold; + match frame.payload { + Payload::Owned(b) if vectored => self.payloads.push_back(WriteBuffer { + position: self.write_buffer.len(), + read_head: 0, + payload: Payload::Owned(b), + }), + Payload::Bytes(b) if vectored => self.payloads.push_back(WriteBuffer { + position: self.write_buffer.len(), + read_head: 0, + payload: Payload::Bytes(b), + }), + _ => { + self.write_buffer.extend_from_slice(&frame.payload); + } + } Ok(()) } @@ -1099,14 +1152,38 @@ impl WriteHalf { where S: AsyncWrite + Unpin, { - while self.read_head < self.buffer.len() { - let written = ready!( - pin!(&mut *stream).poll_write(cx, &self.buffer[self.read_head..]) - )?; + while self.read_head < self.write_buffer.len() || !self.payloads.is_empty() + { + let written = if let Some(front) = self.payloads.front_mut() { + let b = [ + IoSlice::new(&self.write_buffer[self.read_head..front.position]), + IoSlice::new(&front.payload), + ]; + + let written = ready!(pin!(&mut *stream).poll_write_vectored(cx, &b))?; + + if written < b[0].len() { + self.read_head += written; + } else { + let written = written - b[0].len(); + self.read_head = front.position; + front.read_head += written; + if front.read_head == front.payload.len() { + self.payloads.pop_front(); + } + } + + written + } else { + let written = ready!(pin!(&mut *stream) + .poll_write(cx, &self.write_buffer[self.read_head..]))?; + self.read_head += written; + written + }; + if written == 0 { return Poll::Ready(Err(WebSocketError::ConnectionClosed)); } - self.read_head += written; } // flush the stream From 5d450776de6ad18c6bea848a0c8b48aead86f524 Mon Sep 17 00:00:00 2001 From: Conrad Ludgate Date: Mon, 1 Jul 2024 09:46:14 +0100 Subject: [PATCH 19/23] refactor --- src/lib.rs | 107 ++++++++++++++++++++++------------------------------- 1 file changed, 44 insertions(+), 63 deletions(-) diff --git a/src/lib.rs b/src/lib.rs index a51c08f..027e554 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1052,42 +1052,12 @@ impl WriteHalf { where S: AsyncWrite + Unpin, { - if self.read_head >= self.write_buffer.len() && self.payloads.is_empty() { - Poll::Ready(Ok(())) - } else { - let written = if let Some(front) = self.payloads.front_mut() { - let b = [ - IoSlice::new(&self.write_buffer[self.read_head..front.position]), - IoSlice::new(&front.payload), - ]; - - let written = ready!(pin!(&mut *stream).poll_write_vectored(cx, &b))?; - - if written < b[0].len() { - self.read_head += written; - } else { - let written = written - b[0].len(); - self.read_head = front.position; - front.read_head += written; - if front.read_head == front.payload.len() { - self.payloads.pop_front(); - } - } - - written - } else { - let written = ready!(pin!(&mut *stream) - .poll_write(cx, &self.write_buffer[self.read_head..]))?; - self.read_head += written; - written - }; - - if written == 0 { - Poll::Ready(Err(WebSocketError::ConnectionClosed)) - } else { - Poll::Ready(Ok(())) - } + while self.read_head < self.write_buffer.len() || !self.payloads.is_empty() + { + ready!(self.write(stream, cx))?; } + + Poll::Ready(Ok(())) } pub fn start_send_frame<'a>( @@ -1152,42 +1122,53 @@ impl WriteHalf { where S: AsyncWrite + Unpin, { - while self.read_head < self.write_buffer.len() || !self.payloads.is_empty() - { - let written = if let Some(front) = self.payloads.front_mut() { - let b = [ - IoSlice::new(&self.write_buffer[self.read_head..front.position]), - IoSlice::new(&front.payload), - ]; + ready!(self.poll_ready(stream, cx))?; + + // flush the stream + Poll::Ready(ready!(pin!(&mut *stream).poll_flush(cx)).map_err(Into::into)) + } - let written = ready!(pin!(&mut *stream).poll_write_vectored(cx, &b))?; + fn write( + &mut self, + stream: &mut S, + cx: &mut Context<'_>, + ) -> Poll> + where + S: AsyncWrite + Unpin, + { + let written = if let Some(front) = self.payloads.front_mut() { + let b = [ + IoSlice::new(&self.write_buffer[self.read_head..front.position]), + IoSlice::new(&front.payload), + ]; - if written < b[0].len() { - self.read_head += written; - } else { - let written = written - b[0].len(); - self.read_head = front.position; - front.read_head += written; - if front.read_head == front.payload.len() { - self.payloads.pop_front(); - } - } + let written = ready!(pin!(&mut *stream).poll_write_vectored(cx, &b))?; - written + if written < b[0].len() { + self.read_head += written; } else { - let written = ready!(pin!(&mut *stream) + let written = written - b[0].len(); + self.read_head = front.position; + front.read_head += written; + if front.read_head == front.payload.len() { + self.payloads.pop_front(); + } + } + + written + } else { + let written = + ready!(pin!(&mut *stream) .poll_write(cx, &self.write_buffer[self.read_head..]))?; - self.read_head += written; - written - }; + self.read_head += written; + written + }; - if written == 0 { - return Poll::Ready(Err(WebSocketError::ConnectionClosed)); - } + if written == 0 { + return Poll::Ready(Err(WebSocketError::ConnectionClosed)); } - // flush the stream - Poll::Ready(ready!(pin!(&mut *stream).poll_flush(cx)).map_err(Into::into)) + Poll::Ready(Ok(())) } } From 1d579a1c6bbe926ead68c796ae553fd1e962c5ca Mon Sep 17 00:00:00 2001 From: Dario Date: Mon, 1 Jul 2024 12:17:08 +0200 Subject: [PATCH 20/23] Use a Vec instead of BytesMut for the write buffer --- src/frame.rs | 24 +++++++++++++----------- src/lib.rs | 39 ++++++++++++++++++--------------------- 2 files changed, 31 insertions(+), 32 deletions(-) diff --git a/src/frame.rs b/src/frame.rs index 0a34ce6..aa4088e 100644 --- a/src/frame.rs +++ b/src/frame.rs @@ -14,7 +14,7 @@ use tokio::io::AsyncWriteExt; -use bytes::{BufMut, BytesMut}; +use bytes::BytesMut; use core::ops::Deref; use crate::WebSocketError; @@ -259,27 +259,26 @@ impl<'f> Frame<'f> { /// # Panics /// /// This method panics if the head buffer is not at least n-bytes long, where n is the size of the length field (0, 2, 4, or 10) - pub fn fmt_head(&mut self, mut head: impl BufMut) -> usize { - head.put_u8((self.fin as u8) << 7 | (self.opcode as u8)); - - let mask_bit = if self.mask.is_some() { 0x80 } else { 0x0 }; + pub fn fmt_head(&mut self, head: &mut [u8]) -> usize { + head[0] = (self.fin as u8) << 7 | (self.opcode as u8); let len = self.payload.len(); let size = if len < 126 { - head.put_u8(len as u8 | mask_bit); + head[1] = len as u8; 2 } else if len < 65536 { - head.put_u8(126u8 | mask_bit); - head.put_slice(&(len as u16).to_be_bytes()); + head[1] = 126; + head[2..4].copy_from_slice(&(len as u16).to_be_bytes()); 4 } else { - head.put_u8(127u8 | mask_bit); - head.put_slice(&(len as u64).to_be_bytes()); + head[1] = 127; + head[2..10].copy_from_slice(&(len as u64).to_be_bytes()); 10 }; if let Some(mask) = self.mask { - head.put_slice(&mask); + head[1] |= 0x80; + head[size..size + 4].copy_from_slice(&mask); size + 4 } else { size @@ -322,6 +321,9 @@ impl<'f> Frame<'f> { } /// Writes the frame to the buffer and returns a slice of the buffer containing the frame. + /// + /// This function will NOT append the frame to the Vec, but rather replace the current bytes + /// with the frame's serialized bytes. pub fn write<'a>(&mut self, buf: &'a mut Vec) -> &'a [u8] { fn reserve_enough(buf: &mut Vec, len: usize) { if buf.len() < len { diff --git a/src/lib.rs b/src/lib.rs index d699a70..b872c44 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -259,7 +259,9 @@ pub(crate) struct WriteHalf { vectored: bool, auto_apply_mask: bool, writev_threshold: usize, - buffer: BytesMut, + buf: Vec, + buf_offset: usize, + frame_size: usize, } pub(crate) struct ReadHalf { @@ -987,7 +989,9 @@ impl WriteHalf { auto_apply_mask: true, vectored: true, writev_threshold: 1024, - buffer: BytesMut::with_capacity(1024), + buf: Vec::with_capacity(1024), + frame_size: 0, + buf_offset: 0, } } @@ -1029,23 +1033,15 @@ impl WriteHalf { /// call start_send_frame. pub fn poll_ready( &mut self, - stream: &mut S, - cx: &mut Context<'_>, + _stream: &mut S, + _cx: &mut Context<'_>, ) -> Poll> where S: AsyncWrite + Unpin, { - if self.buffer.is_empty() { - Poll::Ready(Ok(())) - } else { - let written = ready!(pin!(&mut *stream).poll_write(cx, &self.buffer))?; - if written == 0 { - Poll::Ready(Err(WebSocketError::ConnectionClosed)) - } else { - self.buffer.advance(written); - Poll::Ready(Ok(())) - } - } + debug_assert_eq!(self.buf_offset, self.frame_size); + // underlying stream is supposed to always be ready + Poll::Ready(Ok(())) } pub fn start_send_frame<'a>( @@ -1066,8 +1062,8 @@ impl WriteHalf { // TODO(dgrr): Cap max payload size with a user setting? - frame.fmt_head(&mut self.buffer); - self.buffer.extend_from_slice(&frame.payload); + self.buf_offset = 0; + self.frame_size = frame.write(&mut self.buf).len(); Ok(()) } @@ -1081,16 +1077,17 @@ impl WriteHalf { S: AsyncWrite + Unpin, { // flush the buffer - while !self.buffer.is_empty() { - let written = ready!(pin!(&mut *stream).poll_write(cx, &self.buffer))?; + while self.buf_offset < self.frame_size { + let written = ready!(pin!(&mut *stream) + .poll_write(cx, &self.buf[self.buf_offset..self.frame_size]))?; if written == 0 { return Poll::Ready(Err(WebSocketError::ConnectionClosed)); } - self.buffer.advance(written); + self.buf_offset += written; } // flush the stream - Poll::Ready(ready!(pin!(&mut *stream).poll_flush(cx)).map_err(Into::into)) + pin!(&mut *stream).poll_flush(cx).map_err(Into::into) } } From 28c331ab5a91cdfd37758a66d4f69019531a9b23 Mon Sep 17 00:00:00 2001 From: Dario Date: Sat, 6 Jul 2024 23:52:42 +0200 Subject: [PATCH 21/23] Remove underscore from used variable names --- src/lib.rs | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/src/lib.rs b/src/lib.rs index b8c37e4..f50dca7 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -166,7 +166,6 @@ pub mod upgrade; use bytes::Buf; -use bytes::Bytes; use bytes::BytesMut; use frame::MAX_HEAD_SIZE; use futures::task::AtomicWaker; @@ -1046,8 +1045,8 @@ impl WriteHalf { /// call start_send_frame. pub fn poll_ready( &mut self, - _stream: &mut S, - _cx: &mut Context<'_>, + stream: &mut S, + cx: &mut Context<'_>, ) -> Poll> where S: AsyncWrite + Unpin, From 7656018e2d0b9076f5e616efcdd9468da0237c67 Mon Sep 17 00:00:00 2001 From: Dario Date: Sun, 7 Jul 2024 00:43:36 +0200 Subject: [PATCH 22/23] Added test to check that simple and vectored serialization do not conflict --- src/fragment.rs | 4 +- src/lib.rs | 106 ++++++++++++++++++++++++++++++++++++++++++++---- 2 files changed, 101 insertions(+), 9 deletions(-) diff --git a/src/fragment.rs b/src/fragment.rs index 091fb03..7a76635 100644 --- a/src/fragment.rs +++ b/src/fragment.rs @@ -105,7 +105,7 @@ impl<'f, S> FragmentCollector { { loop { let (res, obligated_send) = - self.read_half.read_frame_inner(&mut self.stream).await; + self.read_half.read_frame(&mut self.stream).await; let is_closed = self.write_half.closed; if let Some(obligated_send) = obligated_send { if !is_closed { @@ -173,7 +173,7 @@ impl<'f, S> FragmentCollectorRead { { loop { let (res, obligated_send) = - self.read_half.read_frame_inner(&mut self.stream).await; + self.read_half.read_frame(&mut self.stream).await; if let Some(frame) = obligated_send { let res = send_fn(frame).await; res.map_err(|e| WebSocketError::SendError(e.into()))?; diff --git a/src/lib.rs b/src/lib.rs index f50dca7..d3a5a8c 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -273,6 +273,7 @@ struct WriteBuffer { // where in the write_buffer this payload should be inserted position: usize, read_head: usize, + // TODO(dgrr): add a lifetime instead of using 'static? payload: Payload<'static>, } @@ -396,7 +397,7 @@ impl<'f, S> WebSocketRead { { loop { let (res, obligated_send) = - self.read_half.read_frame_inner(&mut self.stream).await; + self.read_half.read_frame(&mut self.stream).await; if let Some(frame) = obligated_send { let res = send_fn(frame).await; res.map_err(|e| WebSocketError::SendError(e.into()))?; @@ -416,7 +417,7 @@ impl<'f, S> WebSocketRead { where S: AsyncRead + Unpin, { - self.read_half.poll_read_frame_inner(&mut self.stream, cx) + self.read_half.poll_read_frame(&mut self.stream, cx) } } @@ -719,7 +720,7 @@ impl<'f, S> WebSocket { { loop { let (res, obligated_send) = - ready!(self.read_half.poll_read_frame_inner(&mut self.stream, cx)); + ready!(self.read_half.poll_read_frame(&mut self.stream, cx)); let is_closed = self.write_half.closed; if let Some(frame) = obligated_send { @@ -767,18 +768,18 @@ impl ReadHalf { /// /// XXX: Do not expose this method to the public API. #[inline(always)] - pub(crate) async fn read_frame_inner<'f, S>( + pub(crate) async fn read_frame<'f, S>( &mut self, stream: &mut S, ) -> (Result>, WebSocketError>, Option>) where S: AsyncRead + Unpin, { - poll_fn(|cx| self.poll_read_frame_inner(stream, cx)).await + poll_fn(|cx| self.poll_read_frame(stream, cx)).await } /// Reads a frame from the Stream. - pub(crate) fn poll_read_frame_inner<'f, S>( + pub(crate) fn poll_read_frame<'f, S>( &mut self, stream: &mut S, cx: &mut Context<'_>, @@ -1063,7 +1064,7 @@ impl WriteHalf { &'a mut self, mut frame: Frame<'a>, ) -> Result<(), WebSocketError> { - // TODO: backpressure check? + // TODO(dario): backpressure check? tokio codec does it if self.role == Role::Client && self.auto_apply_mask { frame.mask(); @@ -1173,6 +1174,8 @@ impl WriteHalf { #[cfg(test)] mod tests { + use std::ops::Deref; + use super::*; const _: () = { @@ -1199,4 +1202,93 @@ mod tests { } assert_unsync::>(); }; + + #[tokio::test] + async fn test_contiguous_simple_and_vectored_writes() { + struct MockStream(Vec); + + impl AsyncRead for MockStream { + fn poll_read( + self: std::pin::Pin<&mut Self>, + _cx: &mut Context<'_>, + buf: &mut tokio::io::ReadBuf<'_>, + ) -> Poll> { + let this = self.get_mut(); + if this.0.is_empty() { + return Poll::Ready(Ok(())); + } + + let size_before = buf.filled().len(); + buf.put_slice(&this.0); + let diff = buf.filled().len() - size_before; + + this.0.drain(..diff); + + Poll::Ready(Ok(())) + } + } + + impl AsyncWrite for MockStream { + fn poll_write( + self: std::pin::Pin<&mut Self>, + _cx: &mut Context<'_>, + buf: &[u8], + ) -> Poll> { + self.get_mut().0.extend_from_slice(buf); + Poll::Ready(Ok(buf.len())) + } + + fn poll_flush( + self: std::pin::Pin<&mut Self>, + _cx: &mut Context<'_>, + ) -> Poll> { + Poll::Ready(Ok(())) + } + + fn poll_shutdown( + self: std::pin::Pin<&mut Self>, + _cx: &mut Context<'_>, + ) -> Poll> { + Poll::Ready(Ok(())) + } + } + + let simple_string = b"1234".to_vec(); + // copy this string more than 1024 times to trigger the vector writes + let long_string = b"A".repeat(1025); + + let mut stream = MockStream(vec![]); + let mut write_half = super::WriteHalf::after_handshake(Role::Server); + let mut read_half = super::ReadHalf::after_handshake(Role::Client); + + poll_fn(|cx| { + // write + assert!(write_half.poll_ready(&mut stream, cx).is_ready()); + // serialize both frames at the same time + assert!(write_half + .start_send_frame(Frame::text(Payload::Owned(simple_string.clone()))) + .is_ok()); + assert!(write_half + .start_send_frame(Frame::text(Payload::Owned(long_string.clone()))) + .is_ok()); + assert!(write_half.poll_flush(&mut stream, cx).is_ready()); + + // read + for body in [&simple_string, &long_string] { + let Poll::Ready((res, mandatory_send)) = + read_half.poll_read_frame(&mut stream, cx) + else { + unreachable!() + }; + + assert!(mandatory_send.is_none()); + + let frame = res.unwrap().unwrap(); + assert_eq!(frame.payload.deref(), body); + } + + Poll::Ready(()) + }) + .await; + } } From b960f1572ce9db7d5ebaae8190d1ad4144a9c556 Mon Sep 17 00:00:00 2001 From: Dario Date: Sun, 7 Jul 2024 12:02:05 +0200 Subject: [PATCH 23/23] WebSocket: Handle obligated send flushing with states to ensure delivery --- src/lib.rs | 151 ++++++++++++++++++++++++++++++++++++++++++++--------- 1 file changed, 126 insertions(+), 25 deletions(-) diff --git a/src/lib.rs b/src/lib.rs index d3a5a8c..3f63389 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -172,6 +172,9 @@ use futures::task::AtomicWaker; use std::collections::VecDeque; use std::future::poll_fn; use std::io::IoSlice; +use std::mem; +use std::ops::Deref; +use std::ops::DerefMut; use std::pin::pin; use std::sync::Arc; use std::task::ready; @@ -473,9 +476,59 @@ impl<'f, S> WebSocketWrite { } } +/// Keep track of the state of the Stream +enum StreamState { + // reading from Stream + Reading(S), + // flushing obligated send + Flushing(S), + // keep the stream here just in case the user wants to access to it + Closed(S), + // used temporarily + None, +} + +impl Deref for StreamState { + type Target = S; + + #[inline(always)] + fn deref(&self) -> &Self::Target { + match self { + StreamState::Reading(stream) => stream, + StreamState::Flushing(stream) => stream, + StreamState::Closed(stream) => stream, + StreamState::None => unreachable!(), + } + } +} + +impl DerefMut for StreamState { + #[inline(always)] + fn deref_mut(&mut self) -> &mut Self::Target { + match self { + StreamState::Reading(stream) => stream, + StreamState::Flushing(stream) => stream, + StreamState::Closed(stream) => stream, + StreamState::None => unreachable!(), + } + } +} + +impl StreamState { + #[inline(always)] + fn into_inner(self) -> S { + match self { + StreamState::Reading(stream) => stream, + StreamState::Flushing(stream) => stream, + StreamState::Closed(stream) => stream, + StreamState::None => unreachable!(), + } + } +} + /// WebSocket protocol implementation over an async stream. pub struct WebSocket { - stream: S, + stream: StreamState, write_half: WriteHalf, read_half: ReadHalf, waker: Arc, @@ -507,8 +560,8 @@ impl<'f, S> WebSocket { { let waker = Arc::new(WakerDemux::default()); Self { - stream, waker, + stream: StreamState::Reading(stream), write_half: WriteHalf::after_handshake(role), read_half: ReadHalf::after_handshake(role), } @@ -545,13 +598,13 @@ impl<'f, S> WebSocket { #[inline] pub fn into_inner(self) -> S { // self.write_half.into_inner().stream - self.stream + self.stream.into_inner() } /// Consumes the `WebSocket` and returns the underlying stream. #[inline] pub(crate) fn into_parts_internal(self) -> (S, ReadHalf, WriteHalf) { - (self.stream, self.read_half, self.write_half) + (self.stream.into_inner(), self.read_half, self.write_half) } /// Sets whether to use vectored writes. This option does not guarantee that vectored writes will be always used. @@ -624,7 +677,10 @@ impl<'f, S> WebSocket { where S: AsyncRead + AsyncWrite + Unpin, { - self.write_half.write_frame(&mut self.stream, frame).await?; + self + .write_half + .write_frame(self.stream.deref_mut(), frame) + .await?; Ok(()) } @@ -671,9 +727,9 @@ impl<'f, S> WebSocket { S: AsyncWrite + Unpin, { self.waker.set_waker(ContextKind::Write, cx.waker()); - self - .waker - .with_context(|cx| self.write_half.poll_flush(&mut self.stream, cx)) + self.waker.with_context(|cx| { + self.write_half.poll_flush(self.stream.deref_mut(), cx) + }) } /// Reads a frame from the stream. @@ -719,27 +775,72 @@ impl<'f, S> WebSocket { S: AsyncRead + AsyncWrite + Unpin, { loop { - let (res, obligated_send) = - ready!(self.read_half.poll_read_frame(&mut self.stream, cx)); + match mem::replace(&mut self.stream, StreamState::None) { + StreamState::None => unreachable!(), + StreamState::Reading(mut stream) => { + let (res, obligated_send) = + match self.read_half.poll_read_frame(&mut stream, cx) { + Poll::Ready(res) => res, + Poll::Pending => { + self.stream = StreamState::Reading(stream); + break Poll::Pending; + } + }; - let is_closed = self.write_half.closed; - if let Some(frame) = obligated_send { - if !is_closed { - self.write_half.start_send_frame(frame)?; - self.waker.set_waker(ContextKind::Read, cx.waker()); - let res = self.waker.with_context(|cx| { - self.write_half.poll_flush(&mut self.stream, cx) - }); - ready!(res)?; - } - } + let is_closed = self.write_half.closed; + + macro_rules! try_send_obligated { + () => { + if let Some(frame) = obligated_send { + // if the write half didn't emit the close frame + if !is_closed { + self.write_half.start_send_frame(frame)?; + self.stream = StreamState::Flushing(stream); + } else { + self.stream = StreamState::Reading(stream); + } + } else { + self.stream = StreamState::Reading(stream); + } + }; + } - if let Some(frame) = res? { - if is_closed && frame.opcode != OpCode::Close { - return Poll::Ready(Err(WebSocketError::ConnectionClosed)); + if let Some(frame) = res? { + if is_closed && frame.opcode != OpCode::Close { + self.stream = StreamState::Closed(stream); + break Poll::Ready(Err(WebSocketError::ConnectionClosed)); + } + + try_send_obligated!(); + break Poll::Ready(Ok(frame)); + } + + try_send_obligated!(); } + StreamState::Flushing(mut stream) => { + self.waker.set_waker(ContextKind::Read, cx.waker()); - break Poll::Ready(Ok(frame)); + let res = self + .waker + .with_context(|cx| self.write_half.poll_flush(&mut stream, cx)); + match res { + Poll::Ready(ok) => { + self.stream = if self.is_closed() { + StreamState::Closed(stream) + } else { + StreamState::Reading(stream) + }; + ok?; + } + Poll::Pending => { + self.stream = StreamState::Flushing(stream); + } + } + } + StreamState::Closed(stream) => { + self.stream = StreamState::Closed(stream); + break Poll::Ready(Err(WebSocketError::ConnectionClosed)); + } } } }