From 425b612b35aa3da4e67546a6c3429b0a3ec417b5 Mon Sep 17 00:00:00 2001 From: Allan Zhang Date: Thu, 30 Nov 2023 21:33:52 -0500 Subject: [PATCH] Add extension read and write traits `tokio::io` has the handy AsyncReadExt and AsyncWriteExt traits, which were preivously used in tests. Now that we rely on Hyper IO, those had to be ported over. That is done in this PR. - AsyncReadExt -> ReadExt - AsyncWriteExt -> WriteExt - tokio internal Read -> ReadFut - tokio internal Write -> WriteFut The read timeout test is failing. I think it has to do with the `impl Future` for `ReadFut`. --- Cargo.toml | 1 + src/lib.rs | 98 +++----------------------- src/stream.rs | 185 +++++++++++++++++++++++++++++++++++++++++++------- 3 files changed, 171 insertions(+), 113 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index 4d094fc..b45edf8 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -22,3 +22,4 @@ tokio = { version = "1.0.0", features = ["io-std", "io-util", "macros"] } hyper = { version = "1.0", features = ["http1"] } hyper-tls = "0.6" http-body-util = "0.1" +tokio-util = { version = "0.7", features = ["io"] } diff --git a/src/lib.rs b/src/lib.rs index 07b9d22..95dcc2b 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -4,10 +4,11 @@ use std::pin::Pin; use std::task::{Context, Poll}; use std::time::Duration; +use hyper::rt::{Read, Write}; use tokio::time::timeout; -use hyper_util::client::legacy::connect::{Connected, Connection}; use hyper::Uri; +use hyper_util::client::legacy::connect::{Connected, Connection}; use tower_service::Service; mod stream; @@ -31,8 +32,8 @@ pub struct TimeoutConnector { impl TimeoutConnector where T: Service + Send, - T::Response: hyper::rt::Read + hyper::rt::Write + Send + Unpin, - T::Future: Send + 'static, + T::Response: Read + Write + Send + Unpin, + T::Future: Send + 'static, T::Error: Into, { /// Construct a new TimeoutConnector with a given connector implementing the `Connect` trait @@ -67,7 +68,7 @@ where let read_timeout = self.read_timeout; let write_timeout = self.write_timeout; let connecting = self.connector.call(dst); - + let fut = async move { let mut stream = match connect_timeout { None => { @@ -132,12 +133,15 @@ where #[cfg(test)] mod tests { - use std::{io, error::Error}; use std::time::Duration; + use std::{error::Error, io}; use http_body_util::Empty; use hyper::body::Bytes; - use hyper_util::{client::legacy::{connect::HttpConnector, Client}, rt::TokioExecutor}; + use hyper_util::{ + client::legacy::{connect::HttpConnector, Client}, + rt::TokioExecutor, + }; use super::TimeoutConnector; @@ -191,85 +195,3 @@ mod tests { } } } - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - diff --git a/src/stream.rs b/src/stream.rs index c73a237..c811e45 100644 --- a/src/stream.rs +++ b/src/stream.rs @@ -1,19 +1,20 @@ //! Wrappers for applying timeouts to IO operations. -//! +//! //! This used to depend on [tokio-io-timeout]. After Hyper 1.0 introduced hyper-specific IO traits, this was rewritten to use hyper IO traits instead of tokio IO traits. -//! +//! //! These timeouts are analogous to the read and write timeouts on traditional blocking sockets. A timeout countdown is //! initiated when a read/write operation returns [`Poll::Pending`]. If a read/write does not return successfully before //! the countdown expires, an [`io::Error`] with a kind of [`TimedOut`](io::ErrorKind::TimedOut) is returned. #![warn(missing_docs)] -use hyper::rt::ReadBufCursor; -use hyper_util::client::legacy::connect::{Connection, Connected}; +use hyper::rt::{Read, ReadBuf, ReadBufCursor, Write}; +use hyper_util::client::legacy::connect::{Connected, Connection}; use pin_project_lite::pin_project; use std::future::Future; use std::io; +use std::marker::PhantomPinned; use std::pin::Pin; -use std::task::{Context, Poll}; +use std::task::{ready, Context, Poll}; use std::time::Duration; use tokio::time::{sleep_until, Instant, Sleep}; @@ -98,7 +99,7 @@ pin_project! { impl TimeoutReader where - R: hyper::rt::Read, + R: Read, { /// Returns a new `TimeoutReader` wrapping the specified reader. /// @@ -152,9 +153,9 @@ where } } -impl hyper::rt::Read for TimeoutReader +impl Read for TimeoutReader where - R: hyper::rt::Read, + R: Read, { fn poll_read( self: Pin<&mut Self>, @@ -171,9 +172,9 @@ where } } -impl hyper::rt::Write for TimeoutReader +impl Write for TimeoutReader where - R: hyper::rt::Write, + R: Write, { fn poll_write( self: Pin<&mut Self>, @@ -217,7 +218,7 @@ pin_project! { impl TimeoutWriter where - W: hyper::rt::Write, + W: Write, { /// Returns a new `TimeoutReader` wrapping the specified reader. /// @@ -271,9 +272,9 @@ where } } -impl hyper::rt::Write for TimeoutWriter +impl Write for TimeoutWriter where - W: hyper::rt::Write, + W: Write, { fn poll_write( self: Pin<&mut Self>, @@ -328,9 +329,9 @@ where } } -impl hyper::rt::Read for TimeoutWriter +impl Read for TimeoutWriter where - W: hyper::rt::Read, + W: Read, { fn poll_read( self: Pin<&mut Self>, @@ -352,7 +353,7 @@ pin_project! { impl TimeoutStream where - S: hyper::rt::Read + hyper::rt::Write, + S: Read + Write, { /// Returns a new `TimeoutStream` wrapping the specified stream. /// @@ -429,9 +430,9 @@ where } } -impl hyper::rt::Read for TimeoutStream +impl Read for TimeoutStream where - S: hyper::rt::Read + hyper::rt::Write, + S: Read + Write, { fn poll_read( self: Pin<&mut Self>, @@ -442,9 +443,9 @@ where } } -impl hyper::rt::Write for TimeoutStream +impl Write for TimeoutStream where - S: hyper::rt::Read + hyper::rt::Write, + S: Read + Write, { fn poll_write( self: Pin<&mut Self>, @@ -477,7 +478,7 @@ where impl Connection for TimeoutStream where - S: hyper::rt::Read + hyper::rt::Write + Connection + Unpin, + S: Read + Write + Connection + Unpin, { fn connected(&self) -> Connected { self.get_ref().connected() @@ -486,23 +487,157 @@ where impl Connection for Pin>> where - S: hyper::rt::Read + hyper::rt::Write + Connection + Unpin, + S: Read + Write + Connection + Unpin, { fn connected(&self) -> Connected { self.get_ref().connected() } } +pin_project! { + /// A future which can be used to easily read available number of bytes to fill + /// a buffer. Based on the internal [tokio::io::util::read::Read] + struct ReadFut<'a, R: ?Sized> { + reader: &'a mut R, + buf: &'a mut [u8], + #[pin] + _pin: PhantomPinned, + } +} + +/// Tries to read some bytes directly into the given `buf` in asynchronous +/// manner, returning a future type. +/// +/// The returned future will resolve to both the I/O stream and the buffer +/// as well as the number of bytes read once the read operation is completed. +fn read<'a, R>(reader: &'a mut R, buf: &'a mut [u8]) -> ReadFut<'a, R> +where + R: Read + Unpin + ?Sized, +{ + ReadFut { + reader, + buf, + _pin: PhantomPinned, + } +} + +impl Future for ReadFut<'_, R> +where + R: Read + Unpin + ?Sized, +{ + type Output = io::Result; + + fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + let me = self.project(); + let mut buf = ReadBuf::new(me.buf); + ready!(Pin::new(me.reader).poll_read(cx, buf.unfilled()))?; + Poll::Ready(Ok(buf.filled().len())) + } +} + +trait ReadExt: Read { + /// Pulls some bytes from this source into the specified buffer, + /// returning how many bytes were read. + fn read<'a>(&'a mut self, buf: &'a mut [u8]) -> ReadFut<'a, Self> + where + Self: Unpin, + { + read(self, buf) + } +} + +pin_project! { + /// A future to write some of the buffer to an `AsyncWrite`.- + struct WriteFut<'a, W: ?Sized> { + writer: &'a mut W, + buf: &'a [u8], + #[pin] + _pin: PhantomPinned, + } +} + +/// Tries to write some bytes from the given `buf` to the writer in an +/// asynchronous manner, returning a future. +fn write<'a, W>(writer: &'a mut W, buf: &'a [u8]) -> WriteFut<'a, W> +where + W: Write + Unpin + ?Sized, +{ + WriteFut { + writer, + buf, + _pin: PhantomPinned, + } +} + +impl Future for WriteFut<'_, W> +where + W: Write + Unpin + ?Sized, +{ + type Output = io::Result; + + fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + let me = self.project(); + Pin::new(&mut *me.writer).poll_write(cx, me.buf) + } +} + +trait WriteExt: Write { + /// Writes a buffer into this writer, returning how many bytes were + /// written. + fn write<'a>(&'a mut self, src: &'a [u8]) -> WriteFut<'a, Self> + where + Self: Unpin, + { + write(self, src) + } +} + +impl ReadExt for Pin<&mut TimeoutReader> +where + R: Read, +{ + fn read<'a>(&'a mut self, buf: &'a mut [u8]) -> ReadFut<'a, Self> { + read(self, buf) + } +} + +impl WriteExt for Pin<&mut TimeoutWriter> +where + W: Write, +{ + fn write<'a>(&'a mut self, src: &'a [u8]) -> WriteFut<'a, Self> { + write(self, src) + } +} + +impl ReadExt for Pin<&mut TimeoutStream> +where + S: Read + Write, +{ + fn read<'a>(&'a mut self, buf: &'a mut [u8]) -> ReadFut<'a, Self> { + read(self, buf) + } +} + +impl WriteExt for Pin<&mut TimeoutStream> +where + S: Read + Write, +{ + fn write<'a>(&'a mut self, src: &'a [u8]) -> WriteFut<'a, Self> { + write(self, src) + } +} + #[cfg(test)] mod test { use super::*; + use hyper_util::rt::TokioIo; use std::io::Write; use std::net::TcpListener; use std::thread; - use hyper_util::rt::TokioIo; use tokio::net::TcpStream; use tokio::pin; - + pin_project! { struct DelayStream { #[pin] @@ -518,7 +653,7 @@ mod test { } } - impl hyper::rt::Read for DelayStream { + impl Read for DelayStream { fn poll_read( self: Pin<&mut Self>, cx: &mut Context,