diff --git a/.cargo/config.toml b/.cargo/config.toml index 99a616d0..a4626bd1 100644 --- a/.cargo/config.toml +++ b/.cargo/config.toml @@ -1,2 +1,4 @@ -[target.x86_64-unknown-linux-gnu] -rustflags = ["-C", "linker=clang", "-C", "link-arg=-fuse-ld=lld"] +[target.'cfg(any())'] +rustflags = [ + "-Wunused-crate-dependencies", +] diff --git a/Cargo.lock b/Cargo.lock index 6c031402..1f61c1ad 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -77,6 +77,10 @@ dependencies = [ "paste", ] +[[package]] +name = "b-x" +version = "1.0.0" + [[package]] name = "backtrace" version = "0.3.71" @@ -137,9 +141,8 @@ checksum = "b048fb63fd8b5923fc5aa7b340d8e156aec7ec02f0c78fa8a6ddc2613f6f71de" name = "buffet" version = "0.2.1" dependencies = [ + "b-x", "bytemuck", - "color-eyre", - "eyre", "http", "io-uring", "libc", @@ -736,6 +739,7 @@ checksum = "df3b46402a9d5adb4c86a0cf463f42e19994e3ee891101b1841f30a545cb49a9" name = "httpwg" version = "0.2.1" dependencies = [ + "b-x", "buffet", "bytes", "enumflags2", @@ -790,9 +794,11 @@ dependencies = [ name = "httpwg-loona" version = "0.2.0" dependencies = [ + "b-x", "buffet", "codspeed-criterion-compat", "color-eyre", + "eyre", "loona", "tokio", "tracing", @@ -1008,14 +1014,13 @@ checksum = "a7a70ba024b9dc04c27ea2f0c0548feb474ec5c54bba33a7f72f873a39d07b24" name = "loona" version = "0.2.1" dependencies = [ + "b-x", "buffet", "byteorder", "bytes", "cargo-husky", "codspeed-criterion-compat", - "color-eyre", "criterion", - "eyre", "futures-util", "http", "httparse", diff --git a/crates/b-x/Cargo.toml b/crates/b-x/Cargo.toml new file mode 100644 index 00000000..3161858f --- /dev/null +++ b/crates/b-x/Cargo.toml @@ -0,0 +1,17 @@ +[package] +name = "b-x" +version = "1.0.0" +edition = "2021" +license = "MIT OR Apache-2.0" +repository = "https://github.com/bearcove/loona" +documentation = "https://docs.rs/b-x" +readme = "README.md" +description = """ +The stupidest boxed error ever +""" +rust-version = "1.75.0" + +[dependencies] + +[features] +default = [] diff --git a/crates/b-x/README.md b/crates/b-x/README.md new file mode 100644 index 00000000..a8881bdc --- /dev/null +++ b/crates/b-x/README.md @@ -0,0 +1,9 @@ +# b-x + +b-x provides the stupidest boxed error ever. + +When you don't want `eyre`, you don't want `thiserror`, you don't want `anyhow`, +you want much, much less. Something that just implements `std::error::Error`. + +It's not even Send. You have to call `.bx()` on results and/or errors via extension +traits. It's so stupid. diff --git a/crates/b-x/src/lib.rs b/crates/b-x/src/lib.rs new file mode 100644 index 00000000..34dca8c6 --- /dev/null +++ b/crates/b-x/src/lib.rs @@ -0,0 +1,112 @@ +use std::{error::Error as StdError, fmt}; + +/// The stupidest box error ever. It's not even Send. +/// +/// It has `From` implementations for some libstd error types, +/// you can derive `From` for your own error types +/// with [make_bxable!] +pub struct BX(Box); + +/// A type alias where `E` defaults to `BX`. +pub type Result = std::result::Result; + +impl BX { + /// Create a new `BX` from an `E`. + pub fn from_err(e: impl StdError + 'static) -> Self { + Self(e.into()) + } + + /// Create a new `BX` from a boxed `E`. + pub fn from_boxed(e: Box) -> Self { + Self(e) + } + + /// Create a new `BX` from a String + pub fn from_string(s: String) -> Self { + Self(s.into()) + } +} + +pub fn box_error(e: impl StdError + 'static) -> BX { + BX::from_err(e) +} + +/// Adds `bx() -> BX` to error types +pub trait BxForErrors { + fn bx(self) -> BX; +} + +impl BxForErrors for E { + fn bx(self) -> BX { + BX::from_err(self) + } +} + +/// Adds `bx() -> Result` to result types +pub trait BxForResults { + fn bx(self) -> Result; +} + +impl BxForResults for Result { + fn bx(self) -> Result { + self.map_err(BX::from_err) + } +} + +impl fmt::Debug for BX { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + self.0.fmt(f) + } +} + +impl fmt::Display for BX { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + self.0.fmt(f) + } +} + +impl StdError for BX { + fn source(&self) -> Option<&(dyn StdError + 'static)> { + self.0.source() + } +} + +/// Implements `From` for `BX` for your own type. +#[macro_export] +macro_rules! make_bxable { + ($ty:ty) => { + impl From<$ty> for $crate::BX { + fn from(e: $ty) -> Self { + $crate::BX::from_err(e) + } + } + }; +} + +make_bxable!(std::io::Error); +make_bxable!(std::fmt::Error); +make_bxable!(std::str::Utf8Error); +make_bxable!(std::string::FromUtf8Error); +make_bxable!(std::string::FromUtf16Error); +make_bxable!(std::num::ParseIntError); +make_bxable!(std::num::ParseFloatError); +make_bxable!(std::num::TryFromIntError); +make_bxable!(std::array::TryFromSliceError); +make_bxable!(std::char::ParseCharError); +make_bxable!(std::net::AddrParseError); +make_bxable!(std::time::SystemTimeError); +make_bxable!(std::env::VarError); +make_bxable!(std::sync::mpsc::RecvError); +make_bxable!(std::sync::mpsc::TryRecvError); +make_bxable!(std::sync::mpsc::SendError>); +make_bxable!(std::sync::PoisonError>); + +#[macro_export] +macro_rules! bail { + ($err:expr) => { + return Err($crate::BX::from_err($err)); + }; + ($fmt:expr, $($arg:tt)*) => { + return Err($crate::BX::from_string(format!($fmt, $($arg)*))); + }; +} diff --git a/crates/buffet/Cargo.toml b/crates/buffet/Cargo.toml index 696431a3..349703e8 100644 --- a/crates/buffet/Cargo.toml +++ b/crates/buffet/Cargo.toml @@ -21,7 +21,6 @@ miri = [] [dependencies] bytemuck = { version = "1.16.3", features = ["extern_crate_std"] } -eyre = "0.6.12" http = "1.1.0" libc = "0.2.155" memchr = "2.7.4" @@ -41,11 +40,11 @@ tokio = { version = "1.39.2", features = [ ] } tracing = "0.1.40" nix = "0.29.0" +b-x = { version = "1.0.0", path = "../b-x" } [target.'cfg(target_os = "linux")'.dependencies] luring = { path = "../luring", version = "0.1.0", optional = true } io-uring = { version = "0.6.4", optional = true } [dev-dependencies] -color-eyre = "0.6.3" pretty_assertions = "1.4.0" diff --git a/crates/buffet/src/bufpool.rs b/crates/buffet/src/bufpool.rs index e99b17ea..7509f898 100644 --- a/crates/buffet/src/bufpool.rs +++ b/crates/buffet/src/bufpool.rs @@ -347,7 +347,7 @@ mod tests { } #[test] - fn freeze_test() -> eyre::Result<()> { + fn freeze_test() { crate::bufpool::initialize_allocator().unwrap(); let total_bufs = num_free(); @@ -372,12 +372,10 @@ mod tests { drop(b2); assert_eq!(total_bufs, num_free()); - - Ok(()) } #[test] - fn split_test() -> eyre::Result<()> { + fn split_test() { crate::bufpool::initialize_allocator().unwrap(); let total_bufs = num_free(); @@ -391,7 +389,5 @@ mod tests { assert_eq!(&b[..6], b"jacket"); drop((a, b)); - - Ok(()) } } diff --git a/crates/buffet/src/bufpool/privatepool.rs b/crates/buffet/src/bufpool/privatepool.rs index 9b80e140..7d1c2105 100644 --- a/crates/buffet/src/bufpool/privatepool.rs +++ b/crates/buffet/src/bufpool/privatepool.rs @@ -7,6 +7,7 @@ use super::BufMut; pub type Result = std::result::Result; #[derive(thiserror::Error, Debug)] +#[non_exhaustive] pub enum Error { #[error("could not mmap buffer")] Mmap(#[from] std::io::Error), @@ -18,6 +19,8 @@ pub enum Error { DoesNotFit, } +b_x::make_bxable!(Error); + thread_local! { static POOL: Pool = const { Pool::new() }; } diff --git a/crates/buffet/src/net/net_uring.rs b/crates/buffet/src/net/net_uring.rs index 3aab84ee..2b77fa42 100644 --- a/crates/buffet/src/net/net_uring.rs +++ b/crates/buffet/src/net/net_uring.rs @@ -206,11 +206,11 @@ mod tests { #[test] fn test_accept() { - color_eyre::install().unwrap(); - - async fn test_accept_inner() -> color_eyre::Result<()> { - let listener = super::TcpListener::bind("127.0.0.1:0".parse().unwrap()).await?; - let addr = listener.local_addr()?; + async fn test_accept_inner() { + let listener = super::TcpListener::bind("127.0.0.1:0".parse().unwrap()) + .await + .unwrap(); + let addr = listener.local_addr().unwrap(); println!("listening on {}", addr); std::thread::spawn(move || { @@ -231,26 +231,23 @@ mod tests { println!("[client] wrote: hello"); }); - let (stream, addr) = listener.accept().await?; + let (stream, addr) = listener.accept().await.unwrap(); println!("accepted connection!, addr={addr:?}"); let (mut r, mut w) = stream.into_halves(); - // write bye - w.write_all_owned("howdy").await?; + w.write_all_owned("howdy").await.unwrap(); let buf = vec![0u8; 1024]; let (res, buf) = r.read_owned(buf).await; - let n = res?; + let n = res.unwrap(); let slice = &buf[..n]; println!( "read {} bytes: {:?}, as string: {:?}", n, slice, - std::str::from_utf8(slice)? + std::str::from_utf8(slice).unwrap() ); - - Ok(()) } - crate::start(async move { test_accept_inner().await.unwrap() }); + crate::start(async move { test_accept_inner().await }); } } diff --git a/crates/buffet/src/roll.rs b/crates/buffet/src/roll.rs index 24fea645..b20dc031 100644 --- a/crates/buffet/src/roll.rs +++ b/crates/buffet/src/roll.rs @@ -1288,27 +1288,32 @@ mod tests { fn test_roll_iobuf() { crate::bufpool::initialize_allocator().unwrap(); + use b_x::{BxForResults, BX}; + use crate::{ - io::{IntoHalves, ReadOwned, WriteOwned}, + io::IntoHalves, net::{TcpListener, TcpStream}, + ReadOwned, WriteOwned, }; - async fn test_roll_iobuf_inner(mut rm: RollMut) -> eyre::Result<()> { + async fn test_roll_iobuf_inner(mut rm: RollMut) -> b_x::Result<()> { rm.put(b"hello").unwrap(); let roll = rm.take_all(); - let ln = TcpListener::bind("127.0.0.1:0".parse()?).await?; - let local_addr = ln.local_addr()?; + let ln = TcpListener::bind("127.0.0.1:0".parse().unwrap()) + .await + .unwrap(); + let local_addr = ln.local_addr().unwrap(); let send_fut = async move { - let stream = TcpStream::connect(local_addr).await?; + let stream = TcpStream::connect(local_addr).await.bx()?; let (_stream_r, mut stream_w) = IntoHalves::into_halves(stream); stream_w.write_all_owned(roll).await?; - Ok::<_, eyre::Report>(()) + Ok::<_, BX>(()) }; let recv_fut = async move { - let (stream, addr) = ln.accept().await?; + let (stream, addr) = ln.accept().await.bx()?; let (mut stream_r, _stream_w) = IntoHalves::into_halves(stream); println!("Accepted connection from {addr}"); @@ -1319,7 +1324,7 @@ mod tests { assert_eq!(&buf[..n], b"hello"); - Ok::<_, eyre::Report>(()) + Ok::<_, BX>(()) }; tokio::try_join!(send_fut, recv_fut)?; diff --git a/crates/httpwg-loona/Cargo.toml b/crates/httpwg-loona/Cargo.toml index d591780a..1d79233c 100644 --- a/crates/httpwg-loona/Cargo.toml +++ b/crates/httpwg-loona/Cargo.toml @@ -27,6 +27,8 @@ buffet = { version = "0.2.1", path = "../buffet" } tracing = { version = "0.1.40", features = ["release_max_level_debug"] } tracing-subscriber = "0.3.18" tokio = { version = "1.39.2", features = ["macros", "sync", "process"] } +eyre = { version = "0.6.12", default-features = false } +b-x = { version = "1.0.0", path = "../b-x" } [dev-dependencies] codspeed-criterion-compat = "2.6.0" diff --git a/crates/httpwg-loona/src/lib.rs b/crates/httpwg-loona/src/lib.rs index c60fccd8..1974c226 100644 --- a/crates/httpwg-loona/src/lib.rs +++ b/crates/httpwg-loona/src/lib.rs @@ -1,11 +1,12 @@ +use b_x::{BxForResults, BX}; use std::{cell::RefCell, rc::Rc}; use tokio::{process::Command, sync::oneshot}; use buffet::{IntoHalves, RollMut}; -use color_eyre::eyre; use loona::{ http::{self, StatusCode}, Body, BodyChunk, Encoder, ExpectResponseHeaders, Responder, Response, ResponseDone, + ServerDriver, }; #[derive(Debug, Clone, Copy)] @@ -164,13 +165,18 @@ pub fn do_main(port: u16, proto: Proto, mode: Mode) { struct TestDriver; -impl loona::ServerDriver for TestDriver { - async fn handle( +impl ServerDriver for TestDriver +where + OurEncoder: Encoder, +{ + type Error = BX; + + async fn handle( &self, _req: loona::Request, req_body: &mut impl Body, - mut res: Responder, - ) -> eyre::Result> { + mut res: Responder, + ) -> Result, Self::Error> { // if the client sent `expect: 100-continue`, we must send a 100 status code if let Some(h) = _req.headers.get(http::header::EXPECT) { if &h[..] == b"100-continue" { @@ -185,7 +191,7 @@ impl loona::ServerDriver for TestDriver { // then read the full request body let mut req_body_len = 0; loop { - let chunk = req_body.next_chunk().await?; + let chunk = req_body.next_chunk().await.bx()?; match chunk { BodyChunk::Done { trailers } => { // yey diff --git a/crates/httpwg/Cargo.toml b/crates/httpwg/Cargo.toml index fccf7268..f6198fbe 100644 --- a/crates/httpwg/Cargo.toml +++ b/crates/httpwg/Cargo.toml @@ -22,3 +22,4 @@ futures-util = "0.3.30" pretty-hex = "0.4.1" tokio = { version = "1.39.2", features = ["time"] } tracing = "0.1.40" +b-x = { version = "1.0.0", path = "../b-x" } diff --git a/crates/loona-h2/src/lib.rs b/crates/loona-h2/src/lib.rs index a5bfa122..29c57d27 100644 --- a/crates/loona-h2/src/lib.rs +++ b/crates/loona-h2/src/lib.rs @@ -404,6 +404,9 @@ impl Frame { } impl IntoPiece for Frame { + /// FIXME: not support happy about this returning an `std::io::Error` + /// really the only way this can fail is if we cannot allocate memory, + /// because we're never actually doing any I/O here. fn into_piece(self, scratch: &mut RollMut) -> std::io::Result { debug_assert_eq!(scratch.len(), 0); self.write_into(&mut *scratch)?; @@ -743,6 +746,7 @@ impl Settings { } #[derive(thiserror::Error, Debug)] +#[non_exhaustive] pub enum SettingsError { #[error("ENABLE_PUSH setting is supposed to be either 0 or 1, got {actual}")] InvalidEnablePushValue { actual: u32 }, diff --git a/crates/loona-hpack/src/decoder.rs b/crates/loona-hpack/src/decoder.rs index 5d7b0425..441ef15e 100644 --- a/crates/loona-hpack/src/decoder.rs +++ b/crates/loona-hpack/src/decoder.rs @@ -200,6 +200,7 @@ impl FieldRepresentation { /// Represents all errors that can be encountered while decoding an /// integer. #[derive(PartialEq, Copy, Clone, Debug, thiserror::Error)] +#[non_exhaustive] pub enum IntegerDecodingError { /// 5.1. specifies that "excessively large integer decodings" MUST be /// considered an error (whether the size is the number of octets or @@ -221,6 +222,7 @@ pub enum IntegerDecodingError { } #[derive(PartialEq, Copy, Clone, Debug, thiserror::Error)] +#[non_exhaustive] pub enum StringDecodingError { #[error("Not enough octets in the buffer")] NotEnoughOctets, @@ -231,6 +233,7 @@ pub enum StringDecodingError { /// Represents all errors that can be encountered while performing the decoding /// of an HPACK header set. #[derive(PartialEq, Copy, Clone, Debug, thiserror::Error)] +#[non_exhaustive] pub enum DecoderError { #[error("Header index out of bounds")] HeaderIndexOutOfBounds, diff --git a/crates/loona-hpack/src/huffman.rs b/crates/loona-hpack/src/huffman.rs index 44fd9904..0d1e3758 100644 --- a/crates/loona-hpack/src/huffman.rs +++ b/crates/loona-hpack/src/huffman.rs @@ -26,6 +26,7 @@ impl HuffmanCodeSymbol { } #[derive(thiserror::Error, PartialEq, Copy, Clone, Debug)] +#[non_exhaustive] pub enum HuffmanDecoderError { /// Any padding strictly larger than 7 bits MUST be interpreted as an error #[error("Padding too large")] diff --git a/crates/loona/Cargo.toml b/crates/loona/Cargo.toml index e94c6cac..68e7033c 100644 --- a/crates/loona/Cargo.toml +++ b/crates/loona/Cargo.toml @@ -24,7 +24,6 @@ harness = false [dependencies] byteorder = "1.5.0" -eyre = { version = "0.6.12", default-features = false } futures-util = "0.3.30" buffet = { version = "0.2.1", path = "../buffet" } loona-hpack = { version = "0.3.2", path = "../loona-hpack" } @@ -41,6 +40,7 @@ thiserror = { version = "1.0.63", default-features = false } tokio = { version = "1.39.2", features = ["macros", "sync"] } tracing = { version = "0.1.40", default-features = false } loona-h2 = { version = "0.2.1", path = "../loona-h2" } +b-x = { version = "1.0.0", path = "../b-x" } [dev-dependencies] buffet = { version = "0.2.1", path = "../buffet" } @@ -66,7 +66,6 @@ futures-util = { version = "0.3.30", default-features = false, features = [ ] } libc = "0.2.155" httpwg = { path = "../httpwg" } -color-eyre = "0.6.3" httpwg-macros = { version = "0.2.1", path = "../httpwg-macros" } cargo-husky = { version = "1", features = ["user-hooks"] } criterion = "0.5.1" diff --git a/crates/loona/examples/tls/linux.rs b/crates/loona/examples/tls/linux.rs index e48fb803..43750881 100644 --- a/crates/loona/examples/tls/linux.rs +++ b/crates/loona/examples/tls/linux.rs @@ -6,7 +6,7 @@ use std::{ sync::Arc, }; -use color_eyre::eyre; +use b_x::{BxForResults, BX}; use http::Version; use ktls::CorkStream; use loona::{ @@ -19,12 +19,11 @@ use tokio::net::TcpListener; use tracing::{debug, info}; use tracing_subscriber::EnvFilter; -pub(crate) fn main() -> eyre::Result<()> { +pub(crate) fn main() { loona::buffet::start(async_main()) } -async fn async_main() -> eyre::Result<()> { - color_eyre::install()?; +async fn async_main() { tracing_subscriber::fmt::fmt() .with_env_filter( EnvFilter::try_from_default_env().unwrap_or_else(|_| EnvFilter::new("info")), @@ -33,10 +32,10 @@ async fn async_main() -> eyre::Result<()> { if std::env::args().any(|a| a == "--get") { sample_http_request().await.unwrap(); - return Ok(()); + return; } - let certified_key = rcgen::generate_simple_self_signed(vec!["localhost".to_string()])?; + let certified_key = rcgen::generate_simple_self_signed(vec!["localhost".to_string()]).unwrap(); let crt = certified_key.cert.der(); let key = certified_key.key_pair.serialize_der(); @@ -55,14 +54,20 @@ async fn async_main() -> eyre::Result<()> { let acceptor = tokio_rustls::TlsAcceptor::from(Arc::new(server_config)); let acceptor = Rc::new(acceptor); - let pt_h1_ln = TcpListener::bind("[::]:7080").await?; - info!("Serving plaintext HTTP/1.1 on {}", pt_h1_ln.local_addr()?); + let pt_h1_ln = TcpListener::bind("[::]:7080").await.unwrap(); + info!( + "Serving plaintext HTTP/1.1 on {}", + pt_h1_ln.local_addr().unwrap() + ); - let pt_h2_ln = TcpListener::bind("[::]:7082").await?; - info!("Serving plaintext HTTP/2 on {}", pt_h2_ln.local_addr()?); + let pt_h2_ln = TcpListener::bind("[::]:7082").await.unwrap(); + info!( + "Serving plaintext HTTP/2 on {}", + pt_h2_ln.local_addr().unwrap() + ); - let tls_ln = TcpListener::bind("[::]:7443").await?; - info!("Serving HTTPS on {}", tls_ln.local_addr()?); + let tls_ln = TcpListener::bind("[::]:7443").await.unwrap(); + info!("Serving HTTPS on {}", tls_ln.local_addr().unwrap()); let h1_conf = Rc::new(h1::ServerConf::default()); let h2_conf = Rc::new(h2::ServerConf::default()); @@ -83,8 +88,6 @@ async fn async_main() -> eyre::Result<()> { } }); } - - Ok::<_, color_eyre::Report>(()) } }; @@ -104,8 +107,6 @@ async fn async_main() -> eyre::Result<()> { } }); } - - Ok::<_, color_eyre::Report>(()) } }; @@ -124,12 +125,9 @@ async fn async_main() -> eyre::Result<()> { } }); } - - Ok::<_, color_eyre::Report>(()) }; - tokio::try_join!(pt_h1_loop, pt_h2_loop, tls_loop)?; - Ok(()) + tokio::join!(pt_h1_loop, pt_h2_loop, tls_loop); } enum Proto { @@ -141,7 +139,7 @@ async fn handle_plaintext_conn( stream: tokio::net::TcpStream, remote_addr: std::net::SocketAddr, proto: Proto, -) -> Result<(), color_eyre::Report> { +) -> Result<(), BX> { info!("Accepted connection from {remote_addr}"); let buf = RollMut::alloc()?; @@ -168,7 +166,7 @@ async fn handle_tls_conn( remote_addr: std::net::SocketAddr, h1_conf: Rc, h2_conf: Rc, -) -> Result<(), color_eyre::Report> { +) -> b_x::Result<()> { info!("Accepted connection from {remote_addr}"); let stream = CorkStream::new(stream); let stream = acceptor.accept(stream).await?; @@ -179,7 +177,7 @@ async fn handle_tls_conn( .and_then(|p| std::str::from_utf8(p).ok().map(|s| s.to_string())); debug!(?alpn_proto, "Performed TLS handshake"); - let stream = ktls::config_ktls_server(stream).await?; + let stream = ktls::config_ktls_server(stream).await.bx()?; debug!("Set up kTLS"); let (drained, stream) = stream.into_raw(); @@ -202,7 +200,9 @@ async fn handle_tls_conn( info!("Using HTTP/1.1"); loona::h1::serve(stream.into_halves(), h1_conf, buf, driver).await?; } - Some(other) => return Err(eyre::eyre!("Unsupported ALPN protocol: {}", other)), + Some(other) => { + b_x::bail!("Unsupported ALPN protocol: {}", other) + } } Ok(()) @@ -210,13 +210,18 @@ async fn handle_tls_conn( struct SDriver {} -impl ServerDriver for SDriver { - async fn handle( +impl ServerDriver for SDriver +where + OurEncoder: Encoder, +{ + type Error = BX; + + async fn handle( &self, mut req: loona::Request, req_body: &mut impl Body, - respond: Responder, - ) -> eyre::Result> { + respond: Responder, + ) -> b_x::Result> { info!("Handling {:?} {}", req.method, req.uri); let addr = "httpbingo.org:80" @@ -240,20 +245,21 @@ impl ServerDriver for SDriver { } } -struct CDriver +struct CDriver where - E: Encoder, + OurEncoder: Encoder, { - respond: Responder, + respond: Responder, } -impl h1::ClientDriver for CDriver +impl h1::ClientDriver for CDriver where - E: Encoder, + OurEncoder: Encoder, { - type Return = Responder; + type Error = BX; + type Return = Responder; - async fn on_informational_response(&mut self, _res: loona::Response) -> eyre::Result<()> { + async fn on_informational_response(&mut self, _res: loona::Response) -> b_x::Result<()> { // ignore informational responses Ok(()) @@ -263,7 +269,7 @@ where self, res: loona::Response, body: &mut impl Body, - ) -> eyre::Result { + ) -> b_x::Result { info!("Client got final response: {}", res.status); let respond = self.respond; @@ -271,7 +277,7 @@ where let trailers = loop { debug!("Reading from body {body:?}"); - match body.next_chunk().await? { + match body.next_chunk().await.bx()? { loona::BodyChunk::Chunk(chunk) => { debug!("Client got chunk of len {}", chunk.len()); @@ -283,16 +289,17 @@ where } }; - respond.finish_body(trailers).await + respond.finish_body(trailers).await.bx() } } struct SampleCDriver {} impl h1::ClientDriver for SampleCDriver { + type Error = BX; type Return = (); - async fn on_informational_response(&mut self, _res: loona::Response) -> eyre::Result<()> { + async fn on_informational_response(&mut self, _res: loona::Response) -> b_x::Result<()> { // ignore informational responses Ok(()) @@ -302,12 +309,12 @@ impl h1::ClientDriver for SampleCDriver { self, res: loona::Response, body: &mut impl Body, - ) -> eyre::Result { + ) -> b_x::Result { info!("Client got final response: {}", res.status); loop { debug!("Reading from body {body:?}"); - match body.next_chunk().await? { + match body.next_chunk().await.bx()? { loona::BodyChunk::Chunk(chunk) => { debug!("Client got chunk of len {}", chunk.len()); } @@ -320,7 +327,7 @@ impl h1::ClientDriver for SampleCDriver { } } -async fn sample_http_request() -> color_eyre::Result<()> { +async fn sample_http_request() -> b_x::Result<()> { info!("Doing sample HTTP request to httpbingo"); let addr = "httpbingo.org:80" diff --git a/crates/loona/examples/tls/main.rs b/crates/loona/examples/tls/main.rs index c8a68806..a8f5ec68 100644 --- a/crates/loona/examples/tls/main.rs +++ b/crates/loona/examples/tls/main.rs @@ -9,6 +9,6 @@ mod non_linux; #[cfg(not(target_os = "linux"))] use non_linux as inner; -fn main() -> color_eyre::Result<()> { +fn main() { inner::main() } diff --git a/crates/loona/examples/tls/non_linux.rs b/crates/loona/examples/tls/non_linux.rs index 5e81f440..c670ac17 100644 --- a/crates/loona/examples/tls/non_linux.rs +++ b/crates/loona/examples/tls/non_linux.rs @@ -1,3 +1,3 @@ -pub(crate) fn main() -> color_eyre::Result<()> { +pub(crate) fn main() { panic!("The loona TLS example is only supported on Linux"); } diff --git a/crates/loona/src/error.rs b/crates/loona/src/error.rs new file mode 100644 index 00000000..8152d840 --- /dev/null +++ b/crates/loona/src/error.rs @@ -0,0 +1,56 @@ +use std::error::Error as StdError; +use std::fmt; + +use b_x::BX; + +use crate::h2::types::H2ConnectionError; + +#[non_exhaustive] +#[derive(Debug, thiserror::Error)] +pub enum ServeError { + /// An error occurred while writing to the downstream + #[error("Error writing to downstream: {0}")] + DownstreamWrite(#[from] std::io::Error), + + /// The server driver errored out + #[error("Server driver error: {0:?}")] + Driver(DriverError), + + /// HTTP/1.1 response body was not drained by response + /// handler before the client closed the connection + #[error("HTTP/1.1 response body was not drained before client closed connection")] + ResponseHandlerBodyNotDrained, + + /// An error occurred while handling an HTTP/2 connection + #[error("HTTP/2 connection error: {0}")] + H2ConnectionError(#[from] H2ConnectionError), + + /// An error occurred during memory allocation + #[error("Memory allocation error: {0}")] + Alloc(#[from] buffet::bufpool::Error), +} + +impl From> for BX +where + DriverError: std::error::Error + 'static, +{ + fn from(e: ServeError) -> Self { + BX::from_err(e) + } +} + +pub struct NeverError; + +impl fmt::Debug for NeverError { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.write_str("NeverError") + } +} + +impl fmt::Display for NeverError { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.write_str("NeverError") + } +} + +impl StdError for NeverError {} diff --git a/crates/loona/src/h1/body.rs b/crates/loona/src/h1/body.rs index ea7c0f98..7cd38e7d 100644 --- a/crates/loona/src/h1/body.rs +++ b/crates/loona/src/h1/body.rs @@ -2,7 +2,7 @@ use std::fmt; use tracing::debug; -use crate::{util::read_and_parse, Body, BodyChunk, BodyErrorReason}; +use crate::{util::read_and_parse, Body, BodyChunk, BodyError}; use buffet::{Piece, PieceList, ReadOwned, RollMut, WriteOwned}; /// An HTTP/1.1 body, either chunked or content-length. @@ -73,7 +73,9 @@ impl H1Body { } } -impl Body for H1Body { +impl Body for H1Body { + type Error = BodyError; + fn content_len(&self) -> Option { match &self.state { Decoder::Chunked(_) => None, @@ -81,7 +83,7 @@ impl Body for H1Body { } } - async fn next_chunk(&mut self) -> eyre::Result { + async fn next_chunk(&mut self) -> Result { if self.buf.is_none() { return Ok(BodyChunk::Done { trailers: None }); } @@ -107,7 +109,7 @@ impl ContentLengthDecoder { &mut self, buf_slot: &mut Option, transport: &mut impl ReadOwned, - ) -> eyre::Result { + ) -> Result { let remain = self.len - self.read; if remain == 0 { return Ok(BodyChunk::Done { trailers: None }); @@ -117,19 +119,19 @@ impl ContentLengthDecoder { let mut buf = buf_slot .take() - .ok_or_else(|| BodyErrorReason::CalledNextChunkAfterError.as_err())?; + .ok_or(BodyError::CalledNextChunkAfterError)?; if buf.is_empty() { buf.reserve()?; let res; (res, buf) = buf.read_into(usize::MAX, transport).await; - res.map_err(|e| BodyErrorReason::ErrorWhileReadingChunkData.with_cx(e))?; + res.map_err(BodyError::ErrorWhileReadingChunkData)?; } let chunk = buf .take_at_most(remain as usize) - .ok_or_else(|| BodyErrorReason::ClosedWhileReadingContentLength.as_err())?; + .ok_or(BodyError::ClosedWhileReadingContentLength)?; self.read += chunk.len() as u64; buf_slot.replace(buf); Ok(BodyChunk::Chunk(chunk.into())) @@ -145,11 +147,11 @@ impl ChunkedDecoder { &mut self, buf_slot: &mut Option, transport: &mut impl ReadOwned, - ) -> eyre::Result { + ) -> Result { loop { let mut buf = buf_slot .take() - .ok_or_else(|| BodyErrorReason::CalledNextChunkAfterError.as_err())?; + .ok_or(BodyError::CalledNextChunkAfterError)?; if let ChunkedDecoder::Done = self { buf_slot.replace(buf); @@ -159,21 +161,30 @@ impl ChunkedDecoder { } if let ChunkedDecoder::ReadingChunkHeader = self { - let (next_buf, chunk_size) = - read_and_parse(super::parse::chunk_size, transport, buf, 16) - .await - .map_err(|e| BodyErrorReason::InvalidChunkSize.with_cx(e))? - .ok_or_else(|| BodyErrorReason::ClosedWhileReadingChunkSize.as_err())?; + let (next_buf, chunk_size) = read_and_parse( + "Http1BodyChunk", + super::parse::chunk_size, + transport, + buf, + 16, + ) + .await + .map_err(|_| BodyError::InvalidChunkSize)? + .ok_or(BodyError::ClosedWhileReadingChunkSize)?; buf = next_buf; if chunk_size == 0 { // that's the final chunk, look for the final CRLF - let (next_buf, _) = read_and_parse(super::parse::crlf, transport, buf, 2) - .await - .map_err(|e| BodyErrorReason::InvalidChunkTerminator.with_cx(e))? - .ok_or_else(|| { - BodyErrorReason::ClosedWhileReadingChunkTerminator.as_err() - })?; + let (next_buf, _) = read_and_parse( + "Http1BodyChunkFinalTerminator", + super::parse::crlf, + transport, + buf, + 2, + ) + .await + .map_err(BodyError::InvalidChunkTerminator)? + .ok_or(BodyError::ClosedWhileReadingChunkTerminator)?; buf = next_buf; *self = ChunkedDecoder::Done; buf_slot.replace(buf); @@ -188,12 +199,16 @@ impl ChunkedDecoder { if let ChunkedDecoder::ReadingChunk { remain } = self { if *remain == 0 { // look for CRLF terminator - let (next_buf, _) = read_and_parse(super::parse::crlf, transport, buf, 2) - .await - .map_err(|e| BodyErrorReason::InvalidChunkTerminator.with_cx(e))? - .ok_or_else(|| { - BodyErrorReason::ClosedWhileReadingChunkTerminator.as_err() - })?; + let (next_buf, _) = read_and_parse( + "Http1BodyChunkTerminator", + super::parse::crlf, + transport, + buf, + 2, + ) + .await + .map_err(BodyError::InvalidChunkTerminator)? + .ok_or(BodyError::ClosedWhileReadingChunkTerminator)?; buf = next_buf; *self = ChunkedDecoder::ReadingChunkHeader; buf_slot.replace(buf); @@ -205,7 +220,7 @@ impl ChunkedDecoder { let res; (res, buf) = buf.read_into(*remain as usize, transport).await; - res.map_err(|e| BodyErrorReason::ErrorWhileReadingChunkData.with_cx(e))?; + res.map_err(BodyError::ErrorWhileReadingChunkData)?; } let chunk = buf.take_at_most(*remain as usize); @@ -216,7 +231,7 @@ impl ChunkedDecoder { return Ok(BodyChunk::Chunk(chunk.into())); } None => { - return Err(BodyErrorReason::ClosedWhileReadingChunkData.as_err().into()); + return Err(BodyError::ClosedWhileReadingChunkData); } } } else { @@ -236,20 +251,38 @@ pub enum BodyWriteMode { Chunked, // we set a length and are writing exactly the number of bytes we promised - ContentLength, + ContentLength(u64), // we didn't set a content-length and we're not doing chunked transfer // encoding, so we're not sending a body at all. Empty, } -pub(crate) async fn write_h1_body( +#[derive(thiserror::Error, Debug)] +pub enum WriteBodyError { + // Error from the `Body` impl itself + #[error("inner body error: {0}")] + InnerBodyError(OurBodyError), + + // BodyError + #[error("body error: {0}")] + BodyError(#[from] BodyError), +} + +pub(crate) async fn write_h1_body( transport: &mut impl WriteOwned, - body: &mut impl Body, + body: &mut B, mode: BodyWriteMode, -) -> eyre::Result<()> { +) -> Result<(), WriteBodyError> +where + B: Body, +{ loop { - match body.next_chunk().await? { + match body + .next_chunk() + .await + .map_err(WriteBodyError::InnerBodyError)? + { BodyChunk::Chunk(chunk) => write_h1_body_chunk(transport, chunk, mode).await?, BodyChunk::Done { .. } => { // TODO: check that we've sent what we announced in terms of @@ -267,7 +300,7 @@ pub(crate) async fn write_h1_body_chunk( transport: &mut impl WriteOwned, chunk: Piece, mode: BodyWriteMode, -) -> eyre::Result<()> { +) -> Result<(), BodyError> { match mode { BodyWriteMode::Chunked => { transport @@ -277,15 +310,17 @@ pub(crate) async fn write_h1_body_chunk( .followed_by(chunk) .followed_by("\r\n"), ) - .await?; + .await + .map_err(BodyError::WriteError)?; } - BodyWriteMode::ContentLength => { - transport.write_all_owned(chunk).await?; + BodyWriteMode::ContentLength(_) => { + transport + .write_all_owned(chunk) + .await + .map_err(BodyError::WriteError)?; } BodyWriteMode::Empty => { - return Err(BodyErrorReason::CalledWriteBodyChunkWhenNoBodyWasExpected - .as_err() - .into()); + return Err(BodyError::CalledWriteBodyChunkWhenNoBodyWasExpected); } } Ok(()) @@ -294,13 +329,16 @@ pub(crate) async fn write_h1_body_chunk( pub(crate) async fn write_h1_body_end( transport: &mut impl WriteOwned, mode: BodyWriteMode, -) -> eyre::Result<()> { +) -> Result<(), BodyError> { debug!(?mode, "writing h1 body end"); match mode { BodyWriteMode::Chunked => { - transport.write_all_owned("0\r\n\r\n").await?; + transport + .write_all_owned("0\r\n\r\n") + .await + .map_err(BodyError::WriteError)?; } - BodyWriteMode::ContentLength => { + BodyWriteMode::ContentLength(..) => { // nothing to do } BodyWriteMode::Empty => { diff --git a/crates/loona/src/h1/client.rs b/crates/loona/src/h1/client.rs index 8970db61..271eb2f2 100644 --- a/crates/loona/src/h1/client.rs +++ b/crates/loona/src/h1/client.rs @@ -1,8 +1,12 @@ -use eyre::Context; +use b_x::BX; use http::header; use tracing::debug; -use crate::{types::Request, util::read_and_parse, Body, HeadersExt, Response}; +use crate::{ + types::Request, + util::{read_and_parse, ReadAndParseError}, + Body, HeadersExt, Response, +}; use buffet::{ PieceList, RollMut, {ReadOwned, WriteOwned}, }; @@ -17,13 +21,42 @@ pub struct ClientConf {} #[allow(async_fn_in_trait)] // we never require Send pub trait ClientDriver { type Return; + type Error: std::error::Error + 'static; - async fn on_informational_response(&mut self, res: Response) -> eyre::Result<()>; + async fn on_informational_response(&mut self, res: Response) -> Result<(), Self::Error>; async fn on_final_response( self, res: Response, body: &mut impl Body, - ) -> eyre::Result; + ) -> Result; +} + +#[derive(Debug, thiserror::Error)] +#[non_exhaustive] +pub enum Http1ClientError { + #[error("An error occurred with the client driver: {0}")] + DriverError(#[source] DriverError), + + #[error("Could not write the request headers")] + WhileWritingRequestHeaders(#[source] std::io::Error), + + #[error("Could not read / receive the response headers")] + ErrorReadingResponseHeaders(#[from] ReadAndParseError), + + #[error("Server went away before sending response headers")] + ServerWentAwayBeforeSendingResponseHeaders, + + #[error("Allocation failed")] + Alloc(#[from] buffet::bufpool::Error), +} + +impl From> for BX +where + DriverError: std::error::Error + 'static, +{ + fn from(e: Http1ClientError) -> Self { + BX::from_err(e) + } } /// Perform an HTTP/1.1 request against an HTTP/1.1 server @@ -35,7 +68,7 @@ pub async fn request( mut req: Request, body: &mut impl Body, driver: D, -) -> eyre::Result<(Option<(R, W)>, D::Return)> +) -> Result<(Option<(R, W)>, D::Return), Http1ClientError> where R: ReadOwned, W: WriteOwned, @@ -48,7 +81,7 @@ where // directly to a `RollMut`, without going through `format!` machinery req.headers .insert(header::CONTENT_LENGTH, len.to_string().into_bytes().into()); - BodyWriteMode::ContentLength + BodyWriteMode::ContentLength(len) } None => BodyWriteMode::Chunked, }; @@ -56,11 +89,12 @@ where let mut buf = RollMut::alloc()?; let mut list = PieceList::default(); - encode_request(req, &mut list, &mut buf)?; + encode_request(req, &mut list, &mut buf) + .map_err(Http1ClientError::WhileWritingRequestHeaders)?; transport_w .writev_all_owned(list) .await - .wrap_err("writing request headers")?; + .map_err(Http1ClientError::WhileWritingRequestHeaders)?; // TODO: handle `expect: 100-continue` (don't start sending body until we get a // 100 response) @@ -71,11 +105,11 @@ where Err(err) => { // TODO: find way to report this error to the driver without // spawning, without ref-counting the driver, etc. - panic!("error writing request body: {err:?}"); + panic!("error writing request body: {err}"); } Ok(_) => { debug!("done writing request body"); - Ok::<_, eyre::Report>(transport_w) + Ok::<_, Http1ClientError>(transport_w) } } } @@ -84,6 +118,7 @@ where let recv_res_fut = { async move { let (buf, res) = read_and_parse( + "Http1Response", super::parse::response, &mut transport_r, buf, @@ -91,8 +126,8 @@ where 64 * 1024, ) .await - .map_err(|e| eyre::eyre!("error reading response headers from server: {e:?}"))? - .ok_or_else(|| eyre::eyre!("server went away before sending response headers"))?; + .map_err(Http1ClientError::ErrorReadingResponseHeaders)? + .ok_or(Http1ClientError::ServerWentAwayBeforeSendingResponseHeaders)?; debug!("client received response"); res.debug_print(); @@ -119,7 +154,10 @@ where let conn_close = res.headers.is_connection_close(); - let ret = driver.on_final_response(res, &mut res_body).await?; + let ret = driver + .on_final_response(res, &mut res_body) + .await + .map_err(Http1ClientError::DriverError)?; let transport_r = match (conn_close, res_body.into_inner()) { // can only re-use the body if conn_close is false and the body was fully draided diff --git a/crates/loona/src/h1/encode.rs b/crates/loona/src/h1/encode.rs index 9dba562a..354fddaa 100644 --- a/crates/loona/src/h1/encode.rs +++ b/crates/loona/src/h1/encode.rs @@ -1,11 +1,10 @@ use std::io::Write; -use eyre::Context; use http::{header, StatusCode, Version}; use crate::{ types::{Headers, Request, Response}, - Encoder, HeadersExt, + BodyError, Encoder, HeadersExt, }; use buffet::{Piece, PieceList, RollMut, WriteOwned}; @@ -15,7 +14,7 @@ pub(crate) fn encode_request( req: Request, list: &mut PieceList, out_scratch: &mut RollMut, -) -> eyre::Result<()> { +) -> Result<(), std::io::Error> { list.push_back(req.method.into_chunk()); list.push_back(" "); @@ -26,7 +25,10 @@ pub(crate) fn encode_request( match req.version { Version::HTTP_10 => list.push_back(" HTTP/1.0\r\n"), Version::HTTP_11 => list.push_back(" HTTP/1.1\r\n"), - _ => return Err(eyre::eyre!("unsupported HTTP version {:?}", req.version)), + _ => panic!( + "passed unsupported HTTP version to HTTP/1.1 request encoder {:?}", + req.version + ), } // TODO: if `host` isn't set, set from request uri? which should @@ -36,11 +38,14 @@ pub(crate) fn encode_request( Ok(()) } -fn encode_response(res: Response, list: &mut PieceList) -> eyre::Result<()> { +fn encode_response(res: Response, list: &mut PieceList) -> Result<(), std::io::Error> { match res.version { Version::HTTP_10 => list.push_back(&b"HTTP/1.0 "[..]), Version::HTTP_11 => list.push_back(&b"HTTP/1.1 "[..]), - _ => return Err(eyre::eyre!("unsupported HTTP version {:?}", res.version)), + _ => panic!( + "passed unsupported HTTP version to HTTP/1.1 response encoder {:?}", + res.version + ), } list.push_back(encode_status_code(res.status)); @@ -52,7 +57,7 @@ fn encode_response(res: Response, list: &mut PieceList) -> eyre::Result<()> { Ok(()) } -pub(crate) fn encode_headers(headers: Headers, list: &mut PieceList) -> eyre::Result<()> { +pub(crate) fn encode_headers(headers: Headers, list: &mut PieceList) -> Result<(), std::io::Error> { let mut last_header_name = None; for (name, value) in headers { match name { @@ -144,19 +149,19 @@ const CODE_DIGITS: &str = "\ 960961962963964965966967968969970971972973974975976977978979\ 980981982983984985986987988989990991992993994995996997998999"; -pub struct H1Encoder +pub struct H1Encoder where - T: WriteOwned, + OurWriteOwned: WriteOwned, { - pub(crate) transport_w: T, + pub(crate) transport_w: OurWriteOwned, mode: BodyWriteMode, } -impl H1Encoder +impl H1Encoder where - T: WriteOwned, + OurWriteOwned: WriteOwned, { - pub fn new(transport_w: T) -> Self { + pub fn new(transport_w: OurWriteOwned) -> Self { Self { transport_w, mode: BodyWriteMode::Empty, @@ -164,26 +169,42 @@ where } } +#[derive(Debug, thiserror::Error)] +#[non_exhaustive] +pub enum H1EncoderError { + #[error("IO error: {0}")] + IoError(#[from] std::io::Error), + #[error("Wrong state: expected {expected:?}, actual {actual:?}")] + WrongState { + expected: BodyWriteMode, + actual: BodyWriteMode, + }, + #[error("Body error: {0}")] + BodyError(#[from] BodyError), +} + +impl AsRef for H1EncoderError { + fn as_ref(&self) -> &(dyn std::error::Error + 'static) { + self + } +} + impl Encoder for H1Encoder where T: WriteOwned, { - async fn write_response(&mut self, mut res: Response) -> eyre::Result<()> { - // TODO: set BodyWriteMode here (and take it out of the Encoder trait, - // h2 doesn't need it) - - if !res.status.is_informational() { - // after this, we expect a body — time to determine the body write mode - if !res.means_empty_body() { - self.mode = match res.headers.content_length() { - Some(0) => BodyWriteMode::Empty, - Some(_) => BodyWriteMode::ContentLength, - None => { - res.headers - .insert(header::TRANSFER_ENCODING, "chunked".into()); - BodyWriteMode::Chunked - } - }; + type Error = H1EncoderError; + + async fn write_response(&mut self, mut res: Response) -> Result<(), Self::Error> { + if !res.status.is_informational() && !res.means_empty_body() { + self.mode = match res.headers.content_length() { + Some(0) => BodyWriteMode::Empty, + Some(length) => BodyWriteMode::ContentLength(length), + None => { + res.headers + .insert(header::TRANSFER_ENCODING, "chunked".into()); + BodyWriteMode::Chunked + } }; } @@ -193,29 +214,34 @@ where self.transport_w .writev_all_owned(list) .await - .wrap_err("writing response headers upstream")?; + .map_err(H1EncoderError::from)?; Ok(()) } - // TODO: move `mode` into `H1Encoder`? we don't need it for h2 - async fn write_body_chunk(&mut self, chunk: Piece) -> eyre::Result<()> { - write_h1_body_chunk(&mut self.transport_w, chunk, self.mode).await + async fn write_body_chunk(&mut self, chunk: Piece) -> Result<(), Self::Error> { + // note: we don't check content length here, because it's done by the Responder, + // note by encoders. + + write_h1_body_chunk(&mut self.transport_w, chunk, self.mode) + .await + .map_err(H1EncoderError::from) } - async fn write_body_end(&mut self) -> eyre::Result<()> { - write_h1_body_end(&mut self.transport_w, self.mode).await + async fn write_body_end(&mut self) -> Result<(), Self::Error> { + write_h1_body_end(&mut self.transport_w, self.mode) + .await + .map_err(H1EncoderError::from) } - async fn write_trailers(&mut self, trailers: Box) -> eyre::Result<()> { - // TODO: check all preconditions + async fn write_trailers(&mut self, trailers: Box) -> Result<(), Self::Error> { let mut list = PieceList::default(); encode_headers(*trailers, &mut list)?; self.transport_w .writev_all_owned(list) .await - .wrap_err("writing response headers upstream")?; + .map_err(H1EncoderError::from)?; Ok(()) } diff --git a/crates/loona/src/h1/server.rs b/crates/loona/src/h1/server.rs index e9f68f87..45c58498 100644 --- a/crates/loona/src/h1/server.rs +++ b/crates/loona/src/h1/server.rs @@ -1,12 +1,12 @@ use std::rc::Rc; -use eyre::Context; use tracing::debug; use crate::{ + error::ServeError, h1::body::{H1Body, H1BodyKind}, - util::{read_and_parse, SemanticError}, - HeadersExt, Responder, ServerDriver, + util::{read_and_parse, ReadAndParseError}, + HeadersExt, Responder, ServeOutcome, ServerDriver, }; use buffet::{ReadOwned, RollMut, WriteOwned}; @@ -33,24 +33,21 @@ impl Default for ServerConf { } } -#[derive(Debug, Clone, Copy, PartialEq, Eq)] -pub enum ServeOutcome { - ClientRequestedConnectionClose, - ServerRequestedConnectionClose, - ClientClosedConnectionBetweenRequests, - // TODO: return buffer there so we can see what they did write? - ClientDidntSpeakHttp11, -} - -pub async fn serve( - (mut transport_r, mut transport_w): (impl ReadOwned, impl WriteOwned), +pub async fn serve( + (mut transport_r, mut transport_w): (OurReadOwned, OurWriteOwned), conf: Rc, mut client_buf: RollMut, - driver: impl ServerDriver, -) -> eyre::Result { + driver: OurDriver, +) -> Result> +where + OurDriver: ServerDriver>, + OurReadOwned: ReadOwned, + OurWriteOwned: WriteOwned, +{ loop { let req; (client_buf, req) = match read_and_parse( + "Http1Request", super::parse::request, &mut transport_r, client_buf, @@ -65,17 +62,22 @@ pub async fn serve( return Ok(ServeOutcome::ClientClosedConnectionBetweenRequests); } }, - Err(e) => { - if let Some(se) = e.downcast_ref::() { + Err(e) => match e { + ReadAndParseError::BufferLimitReachedWhileParsing { limit } => { + debug!("request headers larger than {limit} bytes, replying with 431 and hanging up"); + let reply = b"HTTP/1.1 431 Request Header Fields Too Large\r\n\r\n"; transport_w - .write_all_owned(se.as_http_response()) + .write_all_owned(reply) .await - .wrap_err("writing error response downstream")?; - } + .map_err(ServeError::DownstreamWrite)?; - debug!(?e, "error reading request header from downstream"); - return Ok(ServeOutcome::ClientDidntSpeakHttp11); - } + return Ok(ServeOutcome::RequestHeadersTooLargeOnHttp1Conn); + } + _ => { + debug!(?e, "error reading request header from downstream"); + return Ok(ServeOutcome::ClientDidntSpeakHttp11); + } + }, }; debug!("got request {req:?}"); @@ -98,14 +100,14 @@ pub async fn serve( let resp = driver .handle(req, &mut req_body, responder) .await - .wrap_err("handling request")?; + .map_err(ServeError::Driver)?; // TODO: if we sent `connection: close` we should close now transport_w = resp.into_inner().transport_w; (client_buf, transport_r) = req_body .into_inner() - .ok_or_else(|| eyre::eyre!("request body not drained, have to close connection"))?; + .ok_or(ServeError::ResponseHandlerBodyNotDrained)?; if connection_close { debug!("client requested connection close"); diff --git a/crates/loona/src/h2/body.rs b/crates/loona/src/h2/body.rs index 129811f5..e636cb36 100644 --- a/crates/loona/src/h2/body.rs +++ b/crates/loona/src/h2/body.rs @@ -2,7 +2,7 @@ use core::fmt; use tokio::sync::mpsc; -use crate::{Body, BodyChunk, Headers}; +use crate::{error::NeverError, Body, BodyChunk, Headers}; use buffet::Piece; use super::types::H2StreamError; @@ -21,7 +21,7 @@ pub(crate) enum ChunkPosition { } pub(crate) struct StreamIncoming { - tx: mpsc::Sender, + tx: mpsc::Sender, // total bytes received, which we keep track of, because if the client // announces a content-length and sends fewer or more bytes, we will @@ -34,14 +34,21 @@ pub(crate) struct StreamIncoming { pub(crate) capacity: i64, } +#[derive(Debug, thiserror::Error)] +#[non_exhaustive] +pub enum StreamIncomingError { + #[error("stream reset")] + StreamReset, +} + impl StreamIncoming { pub(crate) fn new( initial_window_size: u32, content_length: Option, - piece_tx: mpsc::Sender>, + tx: mpsc::Sender, ) -> Self { Self { - tx: piece_tx, + tx, total_received: 0, content_length, capacity: initial_window_size as i64, @@ -108,25 +115,36 @@ impl StreamIncoming { Ok(()) } - pub(crate) async fn send_error(&mut self, err: eyre::Report) { + pub(crate) async fn send_error(&mut self, err: StreamIncomingError) { let _ = self.tx.send(Err(err)).await; } } -// FIXME: don't use eyre, do proper error handling -pub(crate) type IncomingMessagesResult = eyre::Result; +pub(crate) type IncomingMessageResult = Result; #[derive(Debug)] pub(crate) struct H2Body { pub(crate) content_length: Option, - pub(crate) eof: bool, + pub(crate) rx: mpsc::Receiver, +} - // TODO: more specific error handling - pub(crate) rx: mpsc::Receiver, +#[derive(Debug, thiserror::Error)] +#[non_exhaustive] +pub(crate) enum H2BodyError { + #[error("Stream reset")] + StreamReset, +} + +impl AsRef for H2BodyError { + fn as_ref(&self) -> &(dyn std::error::Error + 'static) { + self + } } impl Body for H2Body { + type Error = H2BodyError; + fn content_len(&self) -> Option { self.content_length } @@ -135,21 +153,21 @@ impl Body for H2Body { self.eof } - async fn next_chunk(&mut self) -> eyre::Result { + async fn next_chunk(&mut self) -> Result { let chunk = if self.eof { BodyChunk::Done { trailers: None } } else { match self.rx.recv().await { - Some(msg) => match msg? { - IncomingMessage::Piece(piece) => BodyChunk::Chunk(piece), - IncomingMessage::Trailers(trailers) => { + Some(msg) => match msg { + Ok(IncomingMessage::Piece(piece)) => BodyChunk::Chunk(piece), + Ok(IncomingMessage::Trailers(trailers)) => { self.eof = true; BodyChunk::Done { trailers: Some(trailers), } } + Err(StreamIncomingError::StreamReset) => return Err(H2BodyError::StreamReset), }, - // TODO: handle trailers None => { self.eof = true; BodyChunk::Done { trailers: None } @@ -194,15 +212,17 @@ impl SinglePieceBody { } impl Body for SinglePieceBody { + type Error = NeverError; + fn content_len(&self) -> Option { - self.piece.as_ref().map(|piece| piece.len() as u64) + Some(self.content_len) } fn eof(&self) -> bool { self.piece.is_none() } - async fn next_chunk(&mut self) -> eyre::Result { + async fn next_chunk(&mut self) -> Result { tracing::trace!( has_piece = %self.piece.is_some(), "SinglePieceBody::next_chunk"); if let Some(piece) = self.piece.take() { Ok(BodyChunk::Chunk(piece)) diff --git a/crates/loona/src/h2/encode.rs b/crates/loona/src/h2/encode.rs index 87bea652..4ee69c64 100644 --- a/crates/loona/src/h2/encode.rs +++ b/crates/loona/src/h2/encode.rs @@ -7,15 +7,16 @@ use super::types::{H2Event, H2EventPayload}; use crate::{Encoder, Response}; use loona_h2::StreamId; -#[derive(Debug, PartialEq, Eq)] -pub(crate) enum EncoderState { +#[derive(Debug, PartialEq, Eq, Clone, Copy)] +#[non_exhaustive] +pub enum EncoderState { ExpectResponseHeaders, ExpectResponseBody, ResponseDone, } /// Encodes HTTP/2 responses and bodies -pub(crate) struct H2Encoder { +pub struct H2Encoder { stream_id: StreamId, tx: mpsc::Sender, state: EncoderState, @@ -37,25 +38,51 @@ impl H2Encoder { } } - async fn send(&self, payload: H2EventPayload) -> eyre::Result<()> { + async fn send(&self, payload: H2EventPayload) -> Result<(), H2EncoderError> { self.tx .send(self.event(payload)) .await - .map_err(|_| eyre::eyre!("could not send event to h2 connection handler"))?; + .map_err(|_| H2EncoderError::StreamReset)?; Ok(()) } } +#[derive(Debug, thiserror::Error)] +#[non_exhaustive] +pub enum H2EncoderError { + /// The encoder is in the wrong state + #[error("Wrong state: expected {expected:?}, actual {actual:?}")] + WrongState { + expected: EncoderState, + actual: EncoderState, + }, + + #[error("Stream reset")] + StreamReset, +} + +impl AsRef for H2EncoderError { + fn as_ref(&self) -> &(dyn std::error::Error + 'static) { + self + } +} + impl Encoder for H2Encoder { - async fn write_response(&mut self, res: Response) -> eyre::Result<()> { - // TODO: don't panic here + type Error = H2EncoderError; + + async fn write_response(&mut self, res: Response) -> Result<(), Self::Error> { + // FIXME: HTTP/2 _does_ support informational responses, cf. https://github.com/bearcove/loona/issues/190 assert!( !res.status.is_informational(), "http/2 does not support informational responses" ); - // TODO: don't panic here - assert_eq!(self.state, EncoderState::ExpectResponseHeaders); + if self.state != EncoderState::ExpectResponseHeaders { + return Err(H2EncoderError::WrongState { + expected: EncoderState::ExpectResponseHeaders, + actual: self.state, + }); + } self.send(H2EventPayload::Headers(res)).await?; self.state = EncoderState::ExpectResponseBody; @@ -64,16 +91,26 @@ impl Encoder for H2Encoder { } // TODO: BodyWriteMode is not relevant for h2 - async fn write_body_chunk(&mut self, chunk: Piece) -> eyre::Result<()> { - assert!(matches!(self.state, EncoderState::ExpectResponseBody)); + async fn write_body_chunk(&mut self, chunk: Piece) -> Result<(), Self::Error> { + if self.state != EncoderState::ExpectResponseBody { + return Err(H2EncoderError::WrongState { + expected: EncoderState::ExpectResponseBody, + actual: self.state, + }); + } self.send(H2EventPayload::BodyChunk(chunk)).await?; Ok(()) } // TODO: BodyWriteMode is not relevant for h2 - async fn write_body_end(&mut self) -> eyre::Result<()> { - assert!(matches!(self.state, EncoderState::ExpectResponseBody)); + async fn write_body_end(&mut self) -> Result<(), Self::Error> { + if self.state != EncoderState::ExpectResponseBody { + return Err(H2EncoderError::WrongState { + expected: EncoderState::ExpectResponseBody, + actual: self.state, + }); + } self.send(H2EventPayload::BodyEnd).await?; self.state = EncoderState::ResponseDone; @@ -82,8 +119,13 @@ impl Encoder for H2Encoder { } // TODO: handle trailers - async fn write_trailers(&mut self, _trailers: Box) -> eyre::Result<()> { - assert!(matches!(self.state, EncoderState::ResponseDone)); + async fn write_trailers(&mut self, _trailers: Box) -> Result<(), Self::Error> { + if self.state != EncoderState::ResponseDone { + return Err(H2EncoderError::WrongState { + expected: EncoderState::ResponseDone, + actual: self.state, + }); + } todo!("write trailers") } diff --git a/crates/loona/src/h2/mod.rs b/crates/loona/src/h2/mod.rs index f8f7540b..5d4208a2 100644 --- a/crates/loona/src/h2/mod.rs +++ b/crates/loona/src/h2/mod.rs @@ -6,4 +6,6 @@ pub use server::*; mod body; mod encode; -mod types; +pub use encode::H2EncoderError; + +pub(crate) mod types; diff --git a/crates/loona/src/h2/server.rs b/crates/loona/src/h2/server.rs index d1bf4e9e..3585954b 100644 --- a/crates/loona/src/h2/server.rs +++ b/crates/loona/src/h2/server.rs @@ -8,7 +8,6 @@ use std::{ use buffet::{Piece, PieceList, PieceStr, ReadOwned, Roll, RollMut, WriteOwned}; use byteorder::{BigEndian, WriteBytesExt}; -use eyre::Context; use http::{ header, uri::{Authority, PathAndQuery, Scheme}, @@ -25,16 +24,17 @@ use tokio::sync::mpsc; use tracing::{debug, trace}; use crate::{ + error::ServeError, h2::{ - body::{H2Body, IncomingMessagesResult, StreamIncoming}, + body::{H2Body, IncomingMessageResult, StreamIncoming, StreamIncomingError}, encode::H2Encoder, types::{ BodyOutgoing, ConnState, H2ConnectionError, H2Event, H2EventPayload, H2RequestError, H2StreamError, HeadersOrTrailers, HeadersOutgoing, StreamOutgoing, StreamState, }, }, - util::read_and_parse, - Headers, Method, Request, Responder, ServerDriver, + util::{read_and_parse, ReadAndParseError}, + Headers, Method, Request, Responder, ResponderOrBodyError, ServeOutcome, ServerDriver, }; use super::{ @@ -57,16 +57,22 @@ impl Default for ServerConf { } } -pub async fn serve( - (transport_r, transport_w): (impl ReadOwned, impl WriteOwned), +pub async fn serve( + (transport_r, transport_w): (OurReadOwned, OurWriteOwned), conf: Rc, client_buf: RollMut, - driver: Rc, -) -> eyre::Result<()> { + driver: Rc, +) -> Result<(), ServeError> +where + OurDriver: ServerDriver + 'static, + OurReadOwned: ReadOwned, + OurWriteOwned: WriteOwned, +{ let mut state = ConnState::default(); state.self_settings.max_concurrent_streams = conf.max_streams; - let mut cx = ServerContext::new(driver.clone(), state, transport_w)?; + let mut cx = + ServerContext::new(driver.clone(), state, transport_w).map_err(ServeError::Alloc)?; cx.work(client_buf, transport_r).await?; debug!("finished serving"); @@ -74,8 +80,12 @@ pub async fn serve( } /// Reads and processes h2 frames from the client. -pub(crate) struct ServerContext { - driver: Rc, +pub(crate) struct ServerContext +where + OurDriver: ServerDriver + 'static, + OurWriter: WriteOwned, +{ + driver: Rc, state: ConnState, hpack_dec: loona_hpack::Decoder<'static>, @@ -87,14 +97,22 @@ pub(crate) struct ServerContext { /// TODO: encapsulate into a framer, don't /// allow direct access from context methods - transport_w: W, + transport_w: OurWriter, ev_tx: mpsc::Sender, ev_rx: mpsc::Receiver, } -impl ServerContext { - pub(crate) fn new(driver: Rc, state: ConnState, transport_w: W) -> eyre::Result { +impl ServerContext +where + OurDriver: ServerDriver + 'static, + OurWriteOwned: WriteOwned, +{ + pub(crate) fn new( + driver: Rc, + state: ConnState, + transport_w: OurWriteOwned, + ) -> Result { let mut hpack_dec = loona_hpack::Decoder::new(); hpack_dec .set_max_allowed_table_size(Settings::default().header_table_size.try_into().unwrap()); @@ -121,21 +139,22 @@ impl ServerContext { &mut self, mut client_buf: RollMut, mut transport_r: impl ReadOwned, - ) -> eyre::Result<()> { + ) -> Result> { // first read the preface { (client_buf, _) = match read_and_parse( + "Http2Preface", parse::preface, &mut transport_r, client_buf, parse::PREFACE.len(), ) - .await? + .await + .map_err(H2ConnectionError::ReadAndParse)? { Some((client_buf, frame)) => (client_buf, frame), None => { - debug!("h2 client closed connection before sending preface"); - return Ok(()); + return Ok(ServeOutcome::ClientDidntSpeakHttp2); } }; } @@ -156,7 +175,8 @@ impl ServerContext { (Setting::MaxFrameSize, s.max_frame_size), (Setting::MaxHeaderListSize, s.max_header_list_size), ]) - .into_piece(&mut self.out_scratch)? + .into_piece(&mut self.out_scratch) + .map_err(ServeError::DownstreamWrite)? }; let frame = Frame::new( FrameType::Settings(Default::default()), @@ -192,11 +212,11 @@ impl ServerContext { if let Err(e) = res { match e { - H2ConnectionError::ReadError(e) => { + H2ConnectionError::ReadAndParse(e) => { let mut should_ignore_err = false; // if this is a connection reset and we've sent a goaway, ignore it - if let Some(io_error) = e.root_cause().downcast_ref::() { + if let ReadAndParseError::ReadError(io_error) = &e { if io_error.kind() == std::io::ErrorKind::ConnectionReset { should_ignore_err = true; } @@ -204,7 +224,7 @@ impl ServerContext { debug!(%should_ignore_err, "deciding whether or not to propagate deframer error"); if !should_ignore_err { - return Err(e.wrap_err("h2 io")); + return Err(H2ConnectionError::ReadAndParse(e).into()); } }, e => { @@ -218,7 +238,7 @@ impl ServerContext { // what about the GOAWAY? debug!("h2 process task finished with error: {e}"); - return Err(e).wrap_err("h2 process"); + return Err(e.into()); } } res = &mut process_task => { @@ -240,6 +260,7 @@ impl ServerContext { // TODO: figure out graceful shutdown: this would involve sending a goaway // before this point, and processing all the connections we've accepted + // FIXME: we have a GoAway encoder, why are we doing this manually debug!(last_stream_id = %self.state.last_stream_id, ?error_code, "Sending GoAway"); let payload = self.out_scratch @@ -252,10 +273,12 @@ impl ServerContext { })?; let frame = Frame::new(FrameType::GoAway, StreamId::CONNECTION); - self.write_frame(frame, PieceList::single(payload)).await?; + self.write_frame(frame, PieceList::single(payload)) + .await + .map_err(ServeError::H2ConnectionError)?; } - Ok(()) + Ok(ServeOutcome::SuccessfulHttp2GracefulShutdown) } async fn deframe_loop( @@ -269,6 +292,7 @@ impl ServerContext { let frame; trace!("Reading frame... Buffer length: {}", client_buf.len()); let frame_res = read_and_parse( + "Http2Frame", Frame::parse, &mut transport_r, client_buf, @@ -278,7 +302,7 @@ impl ServerContext { let maybe_frame = match frame_res { Ok(inner) => inner, - Err(e) => return Err(H2ConnectionError::ReadError(e)), + Err(e) => return Err(H2ConnectionError::ReadAndParse(e)), }; (client_buf, frame) = match maybe_frame { Some((client_buf, frame)) => (client_buf, frame), @@ -309,12 +333,14 @@ impl ServerContext { ); let mut payload; (client_buf, payload) = match read_and_parse( + "FramePayload", nom::bytes::streaming::take(frame.len as usize), &mut transport_r, client_buf, frame.len as usize, ) - .await? + .await + .map_err(H2ConnectionError::ReadAndParse)? { Some((client_buf, payload)) => (client_buf, payload), None => { @@ -782,7 +808,7 @@ impl ServerContext { debug!(?frame, ">"); let frame_roll = frame .into_piece(&mut self.out_scratch) - .map_err(|e| eyre::eyre!(e))?; + .map_err(H2ConnectionError::WriteError)?; if payload.is_empty() { trace!("Writing frame without payload"); @@ -871,9 +897,11 @@ impl ServerContext { FrameType::Headers(flags) => { if flags.contains(HeadersFlags::Priority) { let pri_spec; - (payload, pri_spec) = PrioritySpec::parse(payload) - .finish() - .map_err(|err| eyre::eyre!("parsing error: {err:?}"))?; + (payload, pri_spec) = PrioritySpec::parse(payload).finish().map_err(|_| { + H2ConnectionError::ReadAndParse(ReadAndParseError::ParsingError { + parser: "PrioritySpec", + }) + })?; debug!(exclusive = %pri_spec.exclusive, stream_dependency = ?pri_spec.stream_dependency, weight = %pri_spec.weight, "received priority, exclusive"); if pri_spec.stream_dependency == frame.stream_id { @@ -1008,7 +1036,13 @@ impl ServerContext { }, &mut SinglePieceBody::new(e.message), ) - .await?; + .await + .map_err(|e| match e { + ResponderOrBodyError::Responder(e) => e, + ResponderOrBodyError::Body(_) => { + unreachable!("SinglePieceBody's error is Infallible") + } + })?; // don't even store the stream state anywhere, just record the last // stream id since we technically processed the request? maybe? @@ -1071,9 +1105,7 @@ impl ServerContext { match ss { StreamState::Open { mut incoming, .. } | StreamState::HalfClosedLocal { mut incoming, .. } => { - incoming - .send_error(eyre::eyre!("Received RST_STREAM from peer")) - .await; + incoming.send_error(StreamIncomingError::StreamReset).await; } StreamState::HalfClosedRemote { .. } => { // good @@ -1205,9 +1237,11 @@ impl ServerContext { }); } - let (_, update) = WindowUpdate::parse(payload) - .finish() - .map_err(|err| eyre::eyre!("parsing error: {err:?}"))?; + let (_, update) = WindowUpdate::parse(payload).finish().map_err(|_| { + H2ConnectionError::ReadAndParse(ReadAndParseError::ParsingError { + parser: "WindowUpdate", + }) + })?; debug!(?update, "Received window update"); if update.increment == 0 { @@ -1767,7 +1801,7 @@ impl ServerContext { let responder = Responder::new(H2Encoder::new(stream_id, self.ev_tx.clone())); - let (piece_tx, piece_rx) = mpsc::channel::(1); // TODO: is 1 a sensible value here? + let (piece_tx, piece_rx) = mpsc::channel::(1); // TODO: is 1 a sensible value here? let req_body = H2Body { content_length, diff --git a/crates/loona/src/h2/types.rs b/crates/loona/src/h2/types.rs index a112af3c..afc3b23f 100644 --- a/crates/loona/src/h2/types.rs +++ b/crates/loona/src/h2/types.rs @@ -8,9 +8,9 @@ use http::StatusCode; use loona_hpack::decoder::DecoderError; use tokio::sync::Notify; -use crate::Response; +use crate::{util::ReadAndParseError, ResponderError, Response}; -use super::body::StreamIncoming; +use super::{body::StreamIncoming, encode::H2EncoderError}; use loona_h2::{FrameType, KnownErrorCode, Settings, SettingsError, StreamId}; pub(crate) struct ConnState { @@ -311,7 +311,8 @@ impl fmt::Debug for H2RequestError { } #[derive(Debug, thiserror::Error)] -pub(crate) enum H2ConnectionError { +#[non_exhaustive] +pub enum H2ConnectionError { #[error("frame too large: {frame_type:?} frame of size {frame_size} exceeds max frame size of {max_frame_size}")] FrameTooLarge { frame_type: FrameType, @@ -374,15 +375,15 @@ pub(crate) enum H2ConnectionError { #[error("stream-specific frame {frame_type:?} sent to stream ID 0 (connection-wide)")] StreamSpecificFrameToConnection { frame_type: FrameType }, - #[error("other error: {0:?}")] - Internal(#[from] eyre::Report), - #[error("error reading/parsing H2 frame: {0:?}")] - ReadError(eyre::Report), + ReadAndParse(ReadAndParseError), #[error("error writing H2 frame: {0:?}")] WriteError(std::io::Error), + #[error("H2 responder error: {0:?}")] + ResponderError(#[from] ResponderError), + #[error("received rst frame for unknown stream")] RstStreamForUnknownStream { stream_id: StreamId }, @@ -445,8 +446,6 @@ impl H2ConnectionError { H2ConnectionError::HpackDecodingError(_) => KnownErrorCode::CompressionError, // stream closed error H2ConnectionError::StreamClosed { .. } => KnownErrorCode::StreamClosed, - // internal errors - H2ConnectionError::Internal(_) => KnownErrorCode::InternalError, // protocol errors H2ConnectionError::PaddedFrameTooShort { .. } => KnownErrorCode::ProtocolError, H2ConnectionError::StreamSpecificFrameToConnection { .. } => { @@ -458,6 +457,7 @@ impl H2ConnectionError { } #[derive(Debug, thiserror::Error)] +#[non_exhaustive] pub(crate) enum H2StreamError { #[error("received {data_length} bytes in data frames but content-length announced {content_length} bytes")] DataLengthDoesNotMatchContentLength { diff --git a/crates/loona/src/lib.rs b/crates/loona/src/lib.rs index 6df208d9..dda09519 100644 --- a/crates/loona/src/lib.rs +++ b/crates/loona/src/lib.rs @@ -1,6 +1,6 @@ +mod types; mod util; -mod types; pub use types::*; pub mod h1; @@ -14,12 +14,19 @@ pub use buffet; /// re-exported so consumers can use whatever forked version we use pub use http; +pub mod error; + #[allow(async_fn_in_trait)] // we never require Send -pub trait ServerDriver { - async fn handle( +pub trait ServerDriver +where + OurEncoder: Encoder, +{ + type Error: std::error::Error + 'static; + + async fn handle( &self, req: Request, req_body: &mut impl Body, - respond: Responder, - ) -> eyre::Result>; + respond: Responder, + ) -> Result, Self::Error>; } diff --git a/crates/loona/src/responder.rs b/crates/loona/src/responder.rs index 893239b8..b3177de1 100644 --- a/crates/loona/src/responder.rs +++ b/crates/loona/src/responder.rs @@ -1,5 +1,6 @@ +use b_x::BX; use buffet::Piece; -use http::header; +use http::{header, StatusCode}; use crate::{Body, BodyChunk, Headers, HeadersExt, Response}; @@ -17,20 +18,56 @@ impl ResponseState for ExpectResponseBody {} pub struct ResponseDone; impl ResponseState for ResponseDone {} -pub struct Responder +#[non_exhaustive] +#[derive(thiserror::Error, Debug)] +pub enum ResponderError { + #[error("interim response must have status code 1xx, got {actual}")] + InterimResponseMustHaveStatusCode1xx { actual: StatusCode }, + + #[error("final response must have status code >= 200, got {actual}")] + FinalResponseMustHaveStatusCodeGreaterThanOrEqualTo200 { actual: StatusCode }, + + #[error( + "body length does not match announced content length: actual {actual}, expected {expected}" + )] + BodyLengthDoesNotMatchAnnouncedContentLength { actual: u64, expected: u64 }, + + #[error("encoder error: {0}")] + EncoderError(#[from] EncoderError), +} + +impl From> for BX where - E: Encoder, - S: ResponseState, + EncoderError: std::error::Error + 'static, { - encoder: E, - state: S, + fn from(e: ResponderError) -> Self { + BX::from_err(e) + } } -impl Responder +#[derive(Debug, thiserror::Error)] +#[non_exhaustive] +pub enum ResponderOrBodyError { + #[error("Responder error: {0}")] + Responder(ResponderError), + #[error("Body error: {0}")] + Body(#[from] BodyError), +} + +pub struct Responder where - E: Encoder, + OurEncoder: Encoder, + OurResponseState: ResponseState, { - pub fn new(encoder: E) -> Self { + encoder: OurEncoder, + state: OurResponseState, +} + +impl Responder +where + OurEncoder: Encoder, +{ + pub fn new(encoder: OurEncoder) -> Self { Self { encoder, state: ExpectResponseHeaders, @@ -39,12 +76,20 @@ where /// Send an informational status code, cf. /// Errors out if the response status is not 1xx - pub async fn write_interim_response(&mut self, res: Response) -> eyre::Result<()> { + pub async fn write_interim_response( + &mut self, + res: Response, + ) -> Result<(), ResponderError> { if !res.status.is_informational() { - return Err(eyre::eyre!("interim response must have status code 1xx")); + return Err(ResponderError::InterimResponseMustHaveStatusCode1xx { + actual: res.status, + }); } - self.encoder.write_response(res).await?; + self.encoder + .write_response(res) + .await + .map_err(ResponderError::EncoderError)?; Ok(()) } @@ -52,11 +97,18 @@ where mut self, res: Response, announced_content_length: Option, - ) -> eyre::Result> { + ) -> Result, ResponderError> { if res.status.is_informational() { - return Err(eyre::eyre!("final response must have status code >= 200")); + return Err( + ResponderError::FinalResponseMustHaveStatusCodeGreaterThanOrEqualTo200 { + actual: res.status, + }, + ); } - self.encoder.write_response(res).await?; + self.encoder + .write_response(res) + .await + .map_err(ResponderError::EncoderError)?; Ok(Responder { state: ExpectResponseBody { announced_content_length, @@ -72,7 +124,7 @@ where pub async fn write_final_response( self, res: Response, - ) -> eyre::Result> { + ) -> ResponderResult, OurEncoder::Error> { let announced_content_length = res.headers.content_length(); self.write_final_response_internal(res, announced_content_length) .await @@ -80,11 +132,17 @@ where /// Writes a response with the given body. Sets `content-length` or /// `transfer-encoding` as needed. - pub async fn write_final_response_with_body( + pub async fn write_final_response_with_body( self, mut res: Response, - body: &mut impl Body, - ) -> eyre::Result> { + body: &mut TheirBody, + ) -> Result< + Responder, + ResponderOrBodyError, + > + where + TheirBody: Body, + { if let Some(clen) = body.content_len() { res.headers .entry(header::CONTENT_LENGTH) @@ -101,15 +159,27 @@ where }); } - let mut this = self.write_final_response(res).await?; + let mut this = self + .write_final_response(res) + .await + .map_err(ResponderOrBodyError::Responder)?; loop { - match body.next_chunk().await? { + match body + .next_chunk() + .await + .map_err(ResponderOrBodyError::Body)? + { BodyChunk::Chunk(chunk) => { - this.write_chunk(chunk).await?; + this.write_chunk(chunk) + .await + .map_err(ResponderOrBodyError::Responder)?; } BodyChunk::Done { trailers } => { - return this.finish_body(trailers).await; + return this + .finish_body(trailers) + .await + .map_err(ResponderOrBodyError::Responder); } } } @@ -123,9 +193,12 @@ where /// Send a response body chunk. Errors out if sending more than the /// announced content-length. #[inline] - pub async fn write_chunk(&mut self, chunk: Piece) -> eyre::Result<()> { + pub async fn write_chunk(&mut self, chunk: Piece) -> ResponderResult<(), E::Error> { self.state.bytes_written += chunk.len() as u64; - self.encoder.write_body_chunk(chunk).await + self.encoder + .write_body_chunk(chunk) + .await + .map_err(ResponderError::EncoderError) } /// Finish the body, with optional trailers, cf. @@ -137,19 +210,27 @@ where pub async fn finish_body( mut self, trailers: Option>, - ) -> eyre::Result> { + ) -> ResponderResult, E::Error> { if let Some(announced_content_length) = self.state.announced_content_length { if self.state.bytes_written != announced_content_length { - eyre::bail!( - "content-length mismatch: announced {announced_content_length}, wrote {}", - self.state.bytes_written + return Err( + ResponderError::BodyLengthDoesNotMatchAnnouncedContentLength { + actual: self.state.bytes_written, + expected: announced_content_length, + }, ); } } - self.encoder.write_body_end().await?; + self.encoder + .write_body_end() + .await + .map_err(ResponderError::EncoderError)?; if let Some(trailers) = trailers { - self.encoder.write_trailers(trailers).await?; + self.encoder + .write_trailers(trailers) + .await + .map_err(ResponderError::EncoderError)?; } Ok(Responder { @@ -168,12 +249,18 @@ where } } +pub type ResponderResult = Result>; + #[allow(async_fn_in_trait)] // we never require Send pub trait Encoder { - async fn write_response(&mut self, res: Response) -> eyre::Result<()>; - async fn write_body_chunk(&mut self, chunk: Piece) -> eyre::Result<()>; - async fn write_body_end(&mut self) -> eyre::Result<()>; - async fn write_trailers(&mut self, trailers: Box) -> eyre::Result<()>; + type Error: std::error::Error + 'static; + + async fn write_response(&mut self, res: Response) -> Result<(), Self::Error>; + /// Note: encoders do not have a duty to check for matching content-length: + /// the responder takes care of that for HTTP/1.1 and HTTP/2 + async fn write_body_chunk(&mut self, chunk: Piece) -> Result<(), Self::Error>; + async fn write_body_end(&mut self) -> Result<(), Self::Error>; + async fn write_trailers(&mut self, trailers: Box) -> Result<(), Self::Error>; } #[cfg(test)] @@ -186,16 +273,18 @@ mod tests { struct MockEncoder; impl Encoder for MockEncoder { - async fn write_response(&mut self, _: Response) -> eyre::Result<()> { + type Error = BX; + + async fn write_response(&mut self, _: Response) -> Result<(), Self::Error> { Ok(()) } - async fn write_body_chunk(&mut self, _: Piece) -> eyre::Result<()> { + async fn write_body_chunk(&mut self, _: Piece) -> Result<(), Self::Error> { Ok(()) } - async fn write_body_end(&mut self) -> eyre::Result<()> { + async fn write_body_end(&mut self) -> Result<(), Self::Error> { Ok(()) } - async fn write_trailers(&mut self, _: Box) -> eyre::Result<()> { + async fn write_trailers(&mut self, _: Box) -> Result<(), Self::Error> { Ok(()) } } @@ -219,7 +308,10 @@ mod tests { responder.write_chunk(b"12345".into()).await.unwrap(); let result = responder.finish_body(None).await; assert!(result.is_err()); - assert!(matches!(result, Err(e) if e.to_string().contains("content-length mismatch"))); + assert!(matches!( + result, + Err(ResponderError::BodyLengthDoesNotMatchAnnouncedContentLength { .. }) + )); // Test writing more bytes than announced let encoder = MockEncoder; @@ -230,7 +322,10 @@ mod tests { responder.write_chunk(b"12345678901".into()).await.unwrap(); let result = responder.finish_body(None).await; assert!(result.is_err()); - assert!(matches!(result, Err(e) if e.to_string().contains("content-length mismatch"))); + assert!(matches!( + result, + Err(ResponderError::BodyLengthDoesNotMatchAnnouncedContentLength { .. }) + )); } } } diff --git a/crates/loona/src/types/mod.rs b/crates/loona/src/types/mod.rs index 966ccfb3..d8b3588d 100644 --- a/crates/loona/src/types/mod.rs +++ b/crates/loona/src/types/mod.rs @@ -11,6 +11,8 @@ pub use headers::*; mod method; pub use method::*; +use crate::{error::NeverError, util::ReadAndParseError}; + /// An HTTP request #[derive(Clone)] pub struct Request { @@ -105,73 +107,64 @@ pub enum BodyChunk { } #[derive(Debug, thiserror::Error)] -pub struct BodyError { - reason: BodyErrorReason, - context: Option>, -} - -impl fmt::Display for BodyError { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - write!(f, "body error: {:?}", self.reason)?; - if let Some(context) = &self.context { - write!(f, " ({context:?})") - } else { - Ok(()) - } - } -} - -#[derive(Debug, Clone, Copy, PartialEq, Eq)] -pub enum BodyErrorReason { - // next_chunk() was called after an error was returned +#[non_exhaustive] +pub enum BodyError { + /// next_chunk() was called after an error was returned + #[error("next_chunk() was called after an error was returned")] CalledNextChunkAfterError, - // while doing chunked transfer-encoding, we expected a chunk size - // but the connection closed/errored in some way + /// while doing chunked transfer-encoding, we expected a chunk size + /// but the connection closed/errored in some way + #[error("connection closed while reading chunk size")] ClosedWhileReadingChunkSize, - // while doing chunked transfer-encoding, we expected a chunk size, - // but what we read wasn't a hex number followed by CRLF + /// while doing chunked transfer-encoding, we expected a chunk size, + /// but what we read wasn't a hex number followed by CRLF + #[error("invalid chunk size")] InvalidChunkSize, - // while doing chunked transfer-encoding, the connection was closed - // in the middle of reading a chunk's data + /// while doing chunked transfer-encoding, the connection was closed + /// in the middle of reading a chunk's data + #[error("connection closed while reading chunk data")] ClosedWhileReadingChunkData, - // while reading a content-length body, the connection was closed + /// while reading a content-length body, the connection was closed + #[error("connection closed while reading content-length body")] ClosedWhileReadingContentLength, - // while doing chunked transfer-encoding, there was a read error - // in the middle of reading a chunk's data - ErrorWhileReadingChunkData, + /// while doing chunked transfer-encoding, there was a read error + /// in the middle of reading a chunk's data + #[error("error while reading chunk data: {0}")] + ErrorWhileReadingChunkData(std::io::Error), - // while doing chunked transfer-encoding, the connection was closed - // in the middle of reading the + /// while doing chunked transfer-encoding, the connection was closed + /// in the middle of reading the + #[error("connection closed while reading chunk terminator")] ClosedWhileReadingChunkTerminator, - // while doing chunked transfer-encoding, we read the chunk size, - // then that much data, but then encountered something other than - // a CRLF - InvalidChunkTerminator, + /// while doing chunked transfer-encoding, we read the chunk size, + /// then that much data, but then encountered something other than + /// a CRLF + #[error("invalid chunk terminator: {0}")] + InvalidChunkTerminator(#[from] ReadAndParseError), - // `write_chunk` was called but no content-length was announced, and - // no chunked transfer-encoding was announced + /// `write_chunk` was called but no content-length was announced, and + /// no chunked transfer-encoding was announced + #[error("write_chunk called when no body was expected")] CalledWriteBodyChunkWhenNoBodyWasExpected, -} -impl BodyErrorReason { - pub fn as_err(self) -> BodyError { - BodyError { - reason: self, - context: None, - } - } + /// Allocation failed + #[error("allocation failed: {0}")] + Alloc(#[from] buffet::bufpool::Error), - pub fn with_cx(self, context: impl Debug + Send + Sync + 'static) -> BodyError { - BodyError { - reason: self, - context: Some(Box::new(context)), - } + /// I/O error while writing + #[error("I/O error while writing: {0}")] + WriteError(std::io::Error), +} + +impl AsRef for BodyError { + fn as_ref(&self) -> &(dyn std::error::Error + 'static) { + self } } @@ -180,12 +173,16 @@ pub trait Body: Debug where Self: Sized, { + type Error: std::error::Error + 'static; + fn content_len(&self) -> Option; fn eof(&self) -> bool; - async fn next_chunk(&mut self) -> eyre::Result; + async fn next_chunk(&mut self) -> Result; } impl Body for () { + type Error = NeverError; + fn content_len(&self) -> Option { Some(0) } @@ -194,7 +191,35 @@ impl Body for () { true } - async fn next_chunk(&mut self) -> eyre::Result { + async fn next_chunk(&mut self) -> Result { Ok(BodyChunk::Done { trailers: None }) } } + +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum ServeOutcome { + /// HTTP/1.1 only: The request we handled had a `connection: close` header + ClientRequestedConnectionClose, + + /// HTTP/1.1 only: The request we handled had a `connection: close` header + ServerRequestedConnectionClose, + + // Client closed connection before sending a second request + /// (without requesting connection close) + ClientClosedConnectionBetweenRequests, + + /// HTTP/1.1 only: Client didn't speak HTTP/1.1 (missing/invalid request + /// line) + ClientDidntSpeakHttp11, + + // We refused to service a request because it was too large. Because it's HTTP/1.1, + /// we had to close the entire connection. + RequestHeadersTooLargeOnHttp1Conn, + + /// HTTP/2 only: Client didn't speak HTTP/2 (missing/invalid request line) + ClientDidntSpeakHttp2, + + /// HTTP/2 only: Client sent a GOAWAY frame, and we've sent a response to + /// the client + SuccessfulHttp2GracefulShutdown, +} diff --git a/crates/loona/src/util.rs b/crates/loona/src/util.rs index 4caac6f3..9b01b885 100644 --- a/crates/loona/src/util.rs +++ b/crates/loona/src/util.rs @@ -1,18 +1,41 @@ -use eyre::Context; use nom::IResult; use pretty_hex::PrettyHex; use tracing::{debug, trace}; use buffet::{ReadOwned, Roll, RollMut}; +use thiserror::Error; + +#[derive(Debug, Error)] +#[non_exhaustive] +pub enum ReadAndParseError { + /// Allocation error + #[error("Allocation error: {0}")] + Alloc(#[from] buffet::bufpool::Error), + + /// Read error + #[error("Read error: {0}")] + ReadError(#[from] std::io::Error), + + /// Buffer limit reached while parsing + #[error("Buffer limit reached while parsing (limit: {limit})")] + BufferLimitReachedWhileParsing { limit: usize }, + + /// Parsing error + // TODO: should we pass any amount of detail here? + #[error("Parsing error in parser: {parser}")] + ParsingError { parser: &'static str }, +} + /// Returns `None` on EOF, error if partially parsed message. pub(crate) async fn read_and_parse( + parser_name: &'static str, parser: Parser, stream: &mut impl ReadOwned, mut buf: RollMut, max_len: usize, // TODO: proper error handling, no eyre::Result -) -> eyre::Result> +) -> Result, ReadAndParseError> where Parser: Fn(Roll) -> IResult, { @@ -37,7 +60,9 @@ where let res; let read_limit = max_len - buf.len(); if buf.len() >= max_len { - return Err(SemanticError::BufferLimitReachedWhileParsing.into()); + return Err(ReadAndParseError::BufferLimitReachedWhileParsing { + limit: max_len, + }); } if buf.cap() == 0 { @@ -51,15 +76,12 @@ where ); (res, buf) = buf.read_into(read_limit, stream).await; - let n = res.wrap_err_with(|| { - format!( - "read_into for read_and_parse::<{}>", - std::any::type_name::() - ) - })?; + let n = res.map_err(ReadAndParseError::ReadError)?; if n == 0 { if !buf.is_empty() { - return Err(eyre::eyre!("unexpected EOF")); + return Err(ReadAndParseError::ReadError( + std::io::ErrorKind::UnexpectedEof.into(), + )); } else { return Ok(None); } @@ -71,25 +93,11 @@ where debug!(?err, "parsing error"); debug!(input = %e.input.to_string_lossy(), "input was"); } - return Err(eyre::eyre!("parsing error: {err}")); + return Err(ReadAndParseError::ParsingError { + parser: parser_name, + }); } } }; } } - -#[derive(thiserror::Error, Debug)] -pub(crate) enum SemanticError { - #[error("buffering limit reached while parsing")] - BufferLimitReachedWhileParsing, -} - -impl SemanticError { - pub(crate) fn as_http_response(&self) -> &'static [u8] { - match self { - Self::BufferLimitReachedWhileParsing => { - b"HTTP/1.1 431 Request Header Fields Too Large\r\n\r\n" - } - } - } -} diff --git a/crates/loona/tests/helpers/mod.rs b/crates/loona/tests/helpers/mod.rs index b306cb74..e3b25903 100644 --- a/crates/loona/tests/helpers/mod.rs +++ b/crates/loona/tests/helpers/mod.rs @@ -1,14 +1,15 @@ use std::future::Future; +use b_x::BX; + pub(crate) mod tracing_common; -pub(crate) fn run(test: impl Future>) { - color_eyre::install().unwrap(); +pub(crate) fn run(test: impl Future>) { loona::buffet::start(async { tracing_common::setup_tracing(); if let Err(e) = test.await { - panic!("Error: {e:?}"); + panic!("Error: {e}"); } }); } diff --git a/crates/loona/tests/httpwg.rs b/crates/loona/tests/httpwg.rs index b4ec6f73..0cc97d6c 100644 --- a/crates/loona/tests/httpwg.rs +++ b/crates/loona/tests/httpwg.rs @@ -1,8 +1,13 @@ +use std::error::Error as StdError; use std::rc::Rc; +use b_x::{BxForResults, BX}; use buffet::{IntoHalves, PipeRead, PipeWrite, ReadOwned, RollMut, WriteOwned}; use http::StatusCode; -use loona::{Body, BodyChunk, Encoder, ExpectResponseHeaders, Responder, Response, ResponseDone}; +use loona::{ + Body, BodyChunk, Encoder, ExpectResponseHeaders, Responder, Response, ResponseDone, + ServerDriver, +}; use tracing::Level; use tracing_subscriber::{filter::Targets, layer::SubscriberExt, util::SubscriberInitExt}; @@ -10,8 +15,6 @@ use tracing_subscriber::{filter::Targets, layer::SubscriberExt, util::Subscriber /// globals. But it will work with `cargo nextest`, and that's what loona is /// standardizing on. pub(crate) fn setup_tracing_and_error_reporting() { - color_eyre::install().unwrap(); - let targets = if let Ok(rust_log) = std::env::var("RUST_LOG") { rust_log.parse::().unwrap() } else { @@ -36,13 +39,19 @@ pub(crate) fn setup_tracing_and_error_reporting() { struct TestDriver; -impl loona::ServerDriver for TestDriver { - async fn handle( +impl ServerDriver for TestDriver +where + OurEncoder: Encoder, + ::Error: AsRef, +{ + type Error = BX; + + async fn handle( &self, _req: loona::Request, req_body: &mut impl Body, - mut res: Responder, - ) -> eyre::Result> { + mut res: Responder, + ) -> Result, BX> { // if the client sent `expect: 100-continue`, we must send a 100 status code if let Some(h) = _req.headers.get(http::header::EXPECT) { if &h[..] == b"100-continue" { @@ -57,7 +66,7 @@ impl loona::ServerDriver for TestDriver { // then read the full request body let mut req_body_len = 0; loop { - let chunk = req_body.next_chunk().await?; + let chunk = req_body.next_chunk().await.bx()?; match chunk { BodyChunk::Done { trailers } => { // yey @@ -113,7 +122,7 @@ pub fn start_server() -> httpwg::Conn> { let io = (server_read, server_write); loona::h2::serve(io, server_conf, client_buf, driver).await?; tracing::debug!("http/2 server done"); - Ok::<_, eyre::Report>(()) + Ok::<_, BX>(()) }; buffet::spawn(async move { diff --git a/crates/loona/tests/integration_test.rs b/crates/loona/tests/integration_test.rs index 9dbf59a2..6c7beaf9 100644 --- a/crates/loona/tests/integration_test.rs +++ b/crates/loona/tests/integration_test.rs @@ -1,5 +1,6 @@ mod helpers; +use b_x::{BxForResults, BX}; use bytes::BytesMut; use http::{header, StatusCode}; use httparse::{Status, EMPTY_HEADER}; @@ -11,9 +12,10 @@ use loona::{ }; use pretty_assertions::assert_eq; use pretty_hex::PrettyHex; -use std::{future::Future, net::SocketAddr, process::Command, rc::Rc, time::Duration}; +use std::{future::Future, net::SocketAddr, rc::Rc, time::Duration}; use tokio::io::{AsyncReadExt, AsyncWriteExt}; use tokio::net::TcpStream; +use tokio::process::Command; use tracing::debug; mod proxy; @@ -46,13 +48,18 @@ fn serve_api() { struct TestDriver; - impl ServerDriver for TestDriver { - async fn handle( + impl ServerDriver for TestDriver + where + OurEncoder: Encoder, + { + type Error = BX; + + async fn handle( &self, _req: loona::Request, _req_body: &mut impl Body, - mut res: Responder, - ) -> eyre::Result> { + mut res: Responder, + ) -> b_x::Result> { let mut buf = RollMut::alloc()?; buf.put(b"Continue")?; @@ -107,7 +114,7 @@ fn serve_api() { let mut headers = [EMPTY_HEADER; 16]; let mut res = httparse::Response::new(&mut headers[..]); - let body_offset = match res.parse(&res_buf[..])? { + let body_offset = match res.parse(&res_buf[..]).bx()? { Status::Complete(off) => off, Status::Partial => { debug!("partial response, continuing"); @@ -132,7 +139,10 @@ fn serve_api() { drop(client_write); - tokio::time::timeout(Duration::from_secs(5), serve_fut).await???; + tokio::time::timeout(Duration::from_secs(5), serve_fut) + .await + .bx()? + .bx()??; Ok(()) }) @@ -154,8 +164,9 @@ fn request_api() { impl h1::ClientDriver for TestDriver { type Return = (); + type Error = BX; - async fn on_informational_response(&mut self, _res: Response) -> eyre::Result<()> { + async fn on_informational_response(&mut self, _res: Response) -> b_x::Result<()> { todo!("got informational response!") } @@ -163,14 +174,14 @@ fn request_api() { self, res: Response, body: &mut impl Body, - ) -> eyre::Result { + ) -> b_x::Result { debug!( "got final response! content length = {:?}, is chunked = {}", res.headers.content_length(), res.headers.is_chunked_transfer_encoding(), ); - while let BodyChunk::Chunk(chunk) = body.next_chunk().await? { + while let BodyChunk::Chunk(chunk) = body.next_chunk().await.bx()? { debug!("got a chunk: {:?}", chunk.hex_dump()); } @@ -197,7 +208,7 @@ fn request_api() { let mut headers = [EMPTY_HEADER; 16]; let mut req = httparse::Request::new(&mut headers[..]); - let body_offset = match req.parse(&req_buf[..])? { + let body_offset = match req.parse(&req_buf[..]).bx()? { Status::Complete(off) => off, Status::Partial => { debug!("partial request, continuing"); @@ -221,7 +232,10 @@ fn request_api() { server_write.write_all_owned(body).await?; drop(server_write); - tokio::time::timeout(Duration::from_secs(5), request_fut).await???; + tokio::time::timeout(Duration::from_secs(5), request_fut) + .await + .bx()? + .bx()??; Ok(()) }) @@ -230,7 +244,7 @@ fn request_api() { #[test] fn proxy_statuses() { #[allow(drop_bounds)] - async fn client(ln_addr: SocketAddr, _guard: impl Drop) -> eyre::Result<()> { + async fn client(ln_addr: SocketAddr, _guard: impl Drop) -> b_x::Result<()> { let mut socket = TcpStream::connect(ln_addr).await?; socket.set_nodelay(true)?; @@ -255,7 +269,7 @@ fn proxy_statuses() { let mut headers = [EMPTY_HEADER; 16]; let mut res = httparse::Response::new(&mut headers[..]); - let _body_offset = match res.parse(&buf[..])? { + let _body_offset = match res.parse(&buf[..]).bx()? { Status::Complete(off) => off, Status::Partial => continue 'read_response, }; @@ -286,7 +300,7 @@ fn proxy_statuses() { #[test] fn proxy_echo_body_content_len() { #[allow(drop_bounds)] - async fn client(ln_addr: SocketAddr, _guard: impl Drop) -> eyre::Result<()> { + async fn client(ln_addr: SocketAddr, _guard: impl Drop) -> b_x::Result<()> { let socket = TcpStream::connect(ln_addr).await?; socket.set_nodelay(true)?; @@ -310,7 +324,7 @@ fn proxy_echo_body_content_len() { write.write_all_owned(body.as_bytes()).await?; write.flush().await?; - Ok::<(), eyre::Report>(()) + Ok::<(), BX>(()) }; loona::buffet::spawn(async move { if let Err(e) = send_fut.await { @@ -329,7 +343,7 @@ fn proxy_echo_body_content_len() { let mut headers = [EMPTY_HEADER; 16]; let mut res = httparse::Response::new(&mut headers[..]); - let body_offset = match res.parse(&buf[..])? { + let body_offset = match res.parse(&buf[..]).bx()? { Status::Complete(off) => off, Status::Partial => continue 'read_response, }; @@ -391,7 +405,7 @@ fn proxy_echo_body_content_len() { #[test] fn proxy_echo_body_chunked() { #[allow(drop_bounds)] - async fn client(ln_addr: SocketAddr, _guard: impl Drop) -> eyre::Result<()> { + async fn client(ln_addr: SocketAddr, _guard: impl Drop) -> b_x::Result<()> { let socket = TcpStream::connect(ln_addr).await?; socket.set_nodelay(true)?; @@ -419,7 +433,7 @@ fn proxy_echo_body_chunked() { write.write_all_owned(&b"0\r\n\r\n"[..]).await?; write.flush().await?; - Ok::<(), eyre::Report>(()) + Ok::<(), BX>(()) }; loona::buffet::spawn(async move { if let Err(e) = send_fut.await { @@ -438,7 +452,7 @@ fn proxy_echo_body_chunked() { let mut headers = [EMPTY_HEADER; 16]; let mut res = httparse::Response::new(&mut headers[..]); - let body_offset = match res.parse(&buf[..])? { + let body_offset = match res.parse(&buf[..]).bx()? { Status::Complete(off) => off, Status::Partial => continue 'read_response, }; @@ -555,7 +569,7 @@ fn curl_echo_body_chunked() { fn curl_echo_body(typ: BodyType) { #[allow(drop_bounds)] - fn client(typ: BodyType, ln_addr: SocketAddr, _guard: impl Drop) -> eyre::Result<()> { + async fn client(typ: BodyType, ln_addr: SocketAddr, _guard: impl Drop) -> b_x::Result<()> { let req_body = "Please return to sender"; let mut cmd = Command::new("curl"); @@ -573,7 +587,8 @@ fn curl_echo_body(typ: BodyType) { cmd.arg("--header").arg("transfer-encoding: chunked"); } - let res_body = cmd.output_assert_success().stdout; + let output = cmd.output_assert_success().await; + let res_body = output.stdout; debug!("Got body: {:?}", res_body.hex_dump()); assert_eq!(res_body.len(), req_body.len()); @@ -585,12 +600,7 @@ fn curl_echo_body(typ: BodyType) { helpers::run(async move { let (upstream_addr, _upstream_guard) = testbed::start().await?; let (ln_addr, guard, proxy_fut) = proxy::start(upstream_addr).await?; - let client_fut = async move { - tokio::task::spawn_blocking(move || client(typ, ln_addr, guard)) - .await - .unwrap() - }; - + let client_fut = client(typ, ln_addr, guard); tokio::try_join!(proxy_fut, client_fut)?; debug!("everything has been joined"); @@ -610,9 +620,9 @@ fn curl_echo_body_noproxy_chunked() { fn curl_echo_body_noproxy(typ: BodyType) { #[allow(drop_bounds)] - fn client(typ: BodyType, ln_addr: SocketAddr, _guard: impl Drop) -> eyre::Result<()> { + async fn client(typ: BodyType, ln_addr: SocketAddr, _guard: impl Drop) -> b_x::Result<()> { let req_body = "Please return to sender"; - let mut cmd = Command::new("curl"); + let mut cmd = tokio::process::Command::new("curl"); cmd.arg("--silent"); cmd.arg("--fail-with-body"); @@ -628,7 +638,8 @@ fn curl_echo_body_noproxy(typ: BodyType) { cmd.arg("--header").arg("transfer-encoding: chunked"); } - let res_body = cmd.output_assert_success().stdout; + let output = cmd.output_assert_success().await; + let res_body = output.stdout; debug!("Got body: {:?}", res_body.hex_dump()); assert_eq!(res_body.len(), req_body.len()); @@ -637,11 +648,8 @@ fn curl_echo_body_noproxy(typ: BodyType) { Ok(()) } - async fn start_server() -> eyre::Result<( - SocketAddr, - impl Drop, - impl Future>, - )> { + async fn start_server( + ) -> b_x::Result<(SocketAddr, impl Drop, impl Future>)> { let (tx, mut rx) = tokio::sync::oneshot::channel::<()>(); let ln = loona::buffet::net::TcpListener::bind("127.0.0.1:0".parse()?).await?; @@ -649,13 +657,18 @@ fn curl_echo_body_noproxy(typ: BodyType) { struct TestDriver; - impl ServerDriver for TestDriver { - async fn handle( + impl ServerDriver for TestDriver + where + OurEncoder: Encoder, + { + type Error = BX; + + async fn handle( &self, req: Request, req_body: &mut impl Body, - mut respond: Responder, - ) -> eyre::Result> { + mut respond: Responder, + ) -> b_x::Result> { if req.headers.expects_100_continue() { debug!("Sending 100-continue"); let res = Response { @@ -677,7 +690,8 @@ fn curl_echo_body_noproxy(typ: BodyType) { }; let respond = respond .write_final_response_with_body(res, req_body) - .await?; + .await + .bx()?; debug!("Wrote final response"); Ok(respond) @@ -738,11 +752,7 @@ fn curl_echo_body_noproxy(typ: BodyType) { helpers::run(async move { let (ln_addr, guard, server_fut) = start_server().await?; - let client_fut = async move { - tokio::task::spawn_blocking(move || client(typ, ln_addr, guard)) - .await - .unwrap() - }; + let client_fut = client(typ, ln_addr, guard); tokio::try_join!(server_fut, client_fut)?; debug!("everything has been joined"); @@ -754,9 +764,9 @@ fn curl_echo_body_noproxy(typ: BodyType) { #[test] fn h2_basic_post() { #[allow(drop_bounds)] - fn client(ln_addr: SocketAddr, _guard: impl Drop) -> eyre::Result<()> { + async fn client(ln_addr: SocketAddr, _guard: impl Drop) -> b_x::Result<()> { let req_body = "Please return to sender"; - let mut cmd = Command::new("curl"); + let mut cmd = tokio::process::Command::new("curl"); cmd.arg("--silent"); cmd.arg("--fail-with-body"); @@ -766,7 +776,8 @@ fn h2_basic_post() { cmd.arg("--header") .arg("content-type: application/octet-stream"); - let res_body = cmd.output_assert_success().stdout; + let output = cmd.output_assert_success().await; + let res_body = output.stdout; debug!("Got body: {:?}", res_body.hex_dump()); assert_eq!(res_body.len(), req_body.len()); @@ -775,11 +786,8 @@ fn h2_basic_post() { Ok(()) } - async fn start_server() -> eyre::Result<( - SocketAddr, - impl Drop, - impl Future>, - )> { + async fn start_server( + ) -> b_x::Result<(SocketAddr, impl Drop, impl Future>)> { let (tx, mut rx) = tokio::sync::oneshot::channel::<()>(); let listen_port: u16 = std::env::var("LISTEN_PORT") @@ -792,13 +800,18 @@ fn h2_basic_post() { struct TestDriver; - impl ServerDriver for TestDriver { - async fn handle( + impl ServerDriver for TestDriver + where + OurEncoder: Encoder, + { + type Error = BX; + + async fn handle( &self, req: Request, req_body: &mut impl Body, - respond: Responder, - ) -> eyre::Result> { + respond: Responder, + ) -> b_x::Result> { debug!("Got request {req:#?}"); debug!("Writing final response"); @@ -813,7 +826,8 @@ fn h2_basic_post() { }; let respond = respond .write_final_response_with_body(res, req_body) - .await?; + .await + .bx()?; debug!("Wrote final response"); Ok(respond) @@ -876,11 +890,7 @@ fn h2_basic_post() { helpers::run(async move { let (ln_addr, guard, server_fut) = start_server().await?; - let client_fut = async move { - tokio::task::spawn_blocking(move || client(ln_addr, guard)) - .await - .unwrap() - }; + let client_fut = client(ln_addr, guard); tokio::try_join!(server_fut, client_fut)?; debug!("everything has been joined"); @@ -901,6 +911,8 @@ impl Default for SampleBody { } impl Body for SampleBody { + type Error = BX; + fn content_len(&self) -> Option { None } @@ -909,7 +921,7 @@ impl Body for SampleBody { self.chunks_remain == 0 } - async fn next_chunk(&mut self) -> eyre::Result { + async fn next_chunk(&mut self) -> b_x::Result { let c = match self.chunks_remain { 0 => BodyChunk::Done { trailers: None }, _ => BodyChunk::Chunk( @@ -928,15 +940,17 @@ impl Body for SampleBody { #[test] fn h2_basic_get() { #[allow(drop_bounds)] - fn client(ln_addr: SocketAddr, _guard: impl Drop) -> eyre::Result<()> { - let mut cmd = Command::new("curl"); + async fn client(ln_addr: SocketAddr, _guard: impl Drop) -> b_x::Result<()> { + let mut cmd = tokio::process::Command::new("curl"); cmd.arg("--silent"); cmd.arg("--fail-with-body"); cmd.arg("--http2-prior-knowledge"); cmd.arg(format!("http://{ln_addr}/stream-big-body")); - let res_body = cmd.output_assert_success().stdout; + let output = cmd.output_assert_success().await; + let res_body = output.stdout; + let ref_body = "this is a big chunk".repeat(256).repeat(128); assert_eq!(res_body.len(), ref_body.len()); assert_eq!(String::from_utf8(res_body).unwrap(), ref_body); @@ -944,11 +958,8 @@ fn h2_basic_get() { Ok(()) } - async fn start_server() -> eyre::Result<( - SocketAddr, - impl Drop, - impl Future>, - )> { + async fn start_server( + ) -> b_x::Result<(SocketAddr, impl Drop, impl Future>)> { let (tx, mut rx) = tokio::sync::oneshot::channel::<()>(); let listen_port: u16 = std::env::var("LISTEN_PORT") @@ -961,13 +972,18 @@ fn h2_basic_get() { struct TestDriver; - impl ServerDriver for TestDriver { - async fn handle( + impl ServerDriver for TestDriver + where + OurEncoder: Encoder, + { + type Error = BX; + + async fn handle( &self, req: Request, _req_body: &mut impl Body, - respond: Responder, - ) -> eyre::Result> { + respond: Responder, + ) -> b_x::Result> { debug!("Got request {req:#?}"); debug!("Writing final response"); @@ -982,7 +998,8 @@ fn h2_basic_get() { }; let respond = respond .write_final_response_with_body(res, &mut SampleBody::default()) - .await?; + .await + .bx()?; debug!("Wrote final response"); Ok(respond) @@ -1045,11 +1062,7 @@ fn h2_basic_get() { helpers::run(async move { let (ln_addr, guard, server_fut) = start_server().await?; - let client_fut = async move { - tokio::task::spawn_blocking(move || client(ln_addr, guard)) - .await - .unwrap() - }; + let client_fut = client(ln_addr, guard); tokio::try_join!(server_fut, client_fut)?; debug!("everything has been joined"); @@ -1059,12 +1072,12 @@ fn h2_basic_get() { } trait CommandExt { - fn output_assert_success(&mut self) -> std::process::Output; + async fn output_assert_success(&mut self) -> std::process::Output; } impl CommandExt for Command { - fn output_assert_success(&mut self) -> std::process::Output { - let output = self.output().unwrap(); + async fn output_assert_success(&mut self) -> std::process::Output { + let output = self.output().await.unwrap(); if !output.status.success() { // print stderr eprintln!("command failed with status: {:?}", output.status); diff --git a/crates/loona/tests/proxy.rs b/crates/loona/tests/proxy.rs index 9da3f38e..23e9579d 100644 --- a/crates/loona/tests/proxy.rs +++ b/crates/loona/tests/proxy.rs @@ -1,3 +1,4 @@ +use b_x::{BxForResults, BX}; use http::StatusCode; use loona::{ buffet::{ @@ -17,13 +18,18 @@ pub struct ProxyDriver { pub pool: TransportPool, } -impl ServerDriver for ProxyDriver { - async fn handle( +impl ServerDriver for ProxyDriver +where + OurEncoder: Encoder, +{ + type Error = BX; + + async fn handle( &self, req: loona::Request, req_body: &mut impl Body, - mut respond: Responder, - ) -> eyre::Result> { + mut respond: Responder, + ) -> Result, BX> { if req.headers.expects_100_continue() { debug!("Sending 100-continue"); let res = Response { @@ -63,20 +69,21 @@ impl ServerDriver for ProxyDriver { } } -struct ProxyClientDriver +struct ProxyClientDriver where - E: Encoder, + OurEncoder: Encoder, { - respond: Responder, + respond: Responder, } -impl h1::ClientDriver for ProxyClientDriver +impl h1::ClientDriver for ProxyClientDriver where - E: Encoder, + OurEncoder: Encoder, { - type Return = Responder; + type Return = Responder; + type Error = BX; - async fn on_informational_response(&mut self, res: Response) -> eyre::Result<()> { + async fn on_informational_response(&mut self, res: Response) -> Result<(), Self::Error> { debug!("Got informational response {}", res.status); Ok(()) } @@ -85,12 +92,12 @@ where self, res: Response, body: &mut impl Body, - ) -> eyre::Result { + ) -> Result { let respond = self.respond; let mut respond = respond.write_final_response(res).await?; let trailers = loop { - match body.next_chunk().await? { + match body.next_chunk().await.bx()? { BodyChunk::Chunk(chunk) => { respond.write_chunk(chunk).await?; } @@ -110,11 +117,7 @@ where pub async fn start( upstream_addr: SocketAddr, -) -> eyre::Result<( - SocketAddr, - impl Drop, - impl Future>, -)> { +) -> b_x::Result<(SocketAddr, impl Drop, impl Future>)> { let (tx, mut rx) = tokio::sync::oneshot::channel::<()>(); let ln = loona::buffet::net::TcpListener::bind("127.0.0.1:0".parse()?).await?; diff --git a/crates/loona/tests/testbed.rs b/crates/loona/tests/testbed.rs index 28e398ef..0e1fc431 100644 --- a/crates/loona/tests/testbed.rs +++ b/crates/loona/tests/testbed.rs @@ -1,5 +1,6 @@ use std::{any::Any, net::SocketAddr, path::PathBuf, process::Stdio}; +use b_x::BxForResults; #[cfg(target_os = "linux")] use libc::{prctl, PR_SET_PDEATHSIG, SIGKILL}; @@ -15,7 +16,7 @@ const EXE_FILE_EXT: &str = ".exe"; #[cfg(not(target_os = "windows"))] const EXE_FILE_EXT: &str = ""; -pub async fn start() -> eyre::Result<(SocketAddr, impl Any)> { +pub async fn start() -> b_x::Result<(SocketAddr, impl Any)> { let (addr_tx, addr_rx) = tokio::sync::oneshot::channel::(); let manifest_dir = std::env::var("CARGO_MANIFEST_DIR").unwrap(); @@ -58,6 +59,6 @@ pub async fn start() -> eyre::Result<(SocketAddr, impl Any)> { } }); - let upstream_addr = addr_rx.await?; + let upstream_addr = addr_rx.await.bx()?; Ok((upstream_addr, child)) }