diff --git a/Cargo.lock b/Cargo.lock index aa3d00ec..8c293413 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -300,6 +300,17 @@ dependencies = [ "tracing", ] +[[package]] +name = "fluke-httpwg-server" +version = "0.1.0" +dependencies = [ + "color-eyre", + "fluke", + "fluke-buffet", + "tracing", + "tracing-subscriber", +] + [[package]] name = "fluke-hyper-testbed" version = "0.1.0" @@ -531,6 +542,19 @@ dependencies = [ "tracing", ] +[[package]] +name = "httpwg-cli" +version = "0.1.0" +dependencies = [ + "color-eyre", + "eyre", + "fluke-buffet", + "httpwg", + "tokio", + "tracing", + "tracing-subscriber", +] + [[package]] name = "httpwg-gen" version = "0.1.0" @@ -1249,9 +1273,9 @@ checksum = "ef927ca75afb808a4d64dd374f00a2adf8d0fcff8e7b184af886c3c87ec4a3f3" [[package]] name = "tokio" -version = "1.37.0" +version = "1.38.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1adbebffeca75fcfd058afa480fb6c0b81e165a0323f9c9d39c9697e37c46787" +checksum = "ba4f4a02a7a80d6f274636f0aa95c7e383b912d41fe721a31f29e29698585a4a" dependencies = [ "backtrace", "bytes", @@ -1268,9 +1292,9 @@ dependencies = [ [[package]] name = "tokio-macros" -version = "2.2.0" +version = "2.3.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5b8a1e28f2deaa14e508979454cb3a223b10b938b45af148bc0986de36f1923b" +checksum = "5f5ae998a069d4b5aba8ee9dad856af7d520c3699e6159b185c2acd48155d39a" dependencies = [ "proc-macro2", "quote", diff --git a/Justfile b/Justfile index 36cc647c..769df200 100644 --- a/Justfile +++ b/Justfile @@ -21,6 +21,9 @@ cov: build-testbed: cargo build --release -p fluke-hyper-testbed +t *args: + just test {{args}} + # Run all tests with cargo nextest test *args: #!/bin/bash diff --git a/crates/fluke-buffet/src/roll.rs b/crates/fluke-buffet/src/roll.rs index 5e4adc69..a48a4928 100644 --- a/crates/fluke-buffet/src/roll.rs +++ b/crates/fluke-buffet/src/roll.rs @@ -52,6 +52,14 @@ impl Debug for StorageMut { } impl StorageMut { + #[inline(always)] + fn off(&self) -> u32 { + match self { + StorageMut::Buf(b) => b.off as _, + StorageMut::Box(b) => b.off, + } + } + #[inline(always)] fn cap(&self) -> usize { match self { @@ -132,6 +140,9 @@ impl RollMut { pub fn grow(&mut self) { let old_cap = self.storage.cap(); let new_cap = old_cap * 2; + + tracing::trace!("growing buffer from {} to {}", old_cap, new_cap); + // TODO: optimize via `MaybeUninit`? let b = vec![0; new_cap].into_boxed_slice(); let mut bs = BoxStorage { @@ -148,16 +159,22 @@ impl RollMut { /// Reallocates the backing storage for this buffer, copying the filled /// portion into it. Panics if `len() == storage_size()`, in which case /// reallocating won't do much good - pub fn realloc(&mut self) -> Result<()> { + /// + /// Also panics if we're using buf storage and the offset is zero + /// (which means reallocating would not free up any space) + pub fn compact(&mut self) -> Result<()> { assert!(self.len() != self.storage_size()); + tracing::trace!("compacting"); let next_storage = match &self.storage { - StorageMut::Buf(_) => { + StorageMut::Buf(bm) => { + assert!(bm.off != 0); let mut next_b = BufMut::alloc()?; next_b[..self.len()].copy_from_slice(&self[..]); StorageMut::Buf(next_b) } StorageMut::Box(b) => { + tracing::trace!("reallocating, storage is box"); if self.len() > BUF_SIZE as usize { // TODO: optimize via `MaybeUninit`? let mut next_b = vec![0; b.cap()].into_boxed_slice(); @@ -186,13 +203,14 @@ impl RollMut { /// to `realloc`. pub fn reserve(&mut self) -> Result<()> { if self.len() < self.cap() { + tracing::trace!("reserve: len < cap, no need to reserve anything"); return Ok(()); } - if self.len() < self.storage_size() { - // we don't need to go up a buffer size + if self.storage.off() > 0 { + // let's try to compact first trace!(len = %self.len(), cap = %self.cap(), storage_size = %self.storage_size(), "in reserve: reallocating"); - self.realloc()? + self.compact()? } else { trace!(len = %self.len(), cap = %self.cap(), storage_size = %self.storage_size(), "in reserve: growing"); self.grow() @@ -203,15 +221,28 @@ impl RollMut { /// Make sure we can hold "request_len" pub fn reserve_at_least(&mut self, requested_len: usize) -> Result<()> { - while self.cap() < requested_len { - if self.cap() < self.storage_size() { - // we don't need to go up a buffer size - self.realloc()? - } else { - self.grow() - } + if requested_len <= self.cap() { + tracing::trace!(%requested_len, cap = %self.cap(), "reserve_at_least: requested_len <= cap, no need to compact"); + return Ok(()); } + if self.storage.off() > 0 && requested_len <= (BUF_SIZE as usize - self.len()) { + // we can compact the filled portion! + self.compact()?; + } else { + // we need to allocate box storage of the right size + let new_storage_size = + std::cmp::max(self.storage_size() * 2, requested_len + self.len()); + let mut new_b = vec![0u8; new_storage_size].into_boxed_slice(); + // copy the filled portion + new_b[..self.len()].copy_from_slice(&self[..]); + self.storage = StorageMut::Box(BoxStorage { + buf: Rc::new(UnsafeCell::new(new_b)), + off: 0, + }); + } + + assert!(self.cap() >= requested_len); Ok(()) } @@ -271,7 +302,8 @@ impl RollMut { (res, read_into.buf) } - /// Put a slice into this buffer, fails if the slice doesn't fit in the buffer's capacity + /// Put a slice into this buffer, fails if the slice doesn't fit in the + /// buffer's capacity pub fn put(&mut self, s: impl AsRef<[u8]>) -> Result<()> { let s = s.as_ref(); @@ -981,7 +1013,7 @@ mod tests { rm.take_all(); assert_eq!(rm.cap(), init_cap - 5); - rm.realloc().unwrap(); + rm.compact().unwrap(); assert_eq!(rm.cap(), BUF_SIZE as usize); } @@ -1000,7 +1032,7 @@ mod tests { let put = "x".repeat(rm.cap() * 2 / 3); rm.put(&put).unwrap(); - rm.realloc().unwrap(); + rm.compact().unwrap(); assert_eq!(rm.storage_size(), BUF_SIZE as usize * 2); assert_eq!(rm.len(), put.len()); @@ -1336,4 +1368,16 @@ mod tests { let roll = rm.take_all(); assert_eq!(std::str::from_utf8(&roll).unwrap(), "hello"); } + + #[test] + fn test_reallocate_big() { + let mut rm = RollMut::alloc().unwrap(); + rm.put(b"baba yaga").unwrap(); + let filled = rm.filled(); + let (_frame, rest) = filled.split_at(4); + rm.keep(rest); + + rm.reserve_at_least(5263945).unwrap(); + assert!(rm.cap() >= 5263945); + } } diff --git a/crates/fluke-h2-parse/src/lib.rs b/crates/fluke-h2-parse/src/lib.rs index 9c6c60b1..74f94ced 100644 --- a/crates/fluke-h2-parse/src/lib.rs +++ b/crates/fluke-h2-parse/src/lib.rs @@ -273,7 +273,7 @@ impl fmt::Debug for Frame { FrameType::WindowUpdate => "WindowUpdate", FrameType::Continuation(_) => "Continuation", FrameType::Unknown(EncodedFrameType { ty, flags }) => { - return write!(f, "UnknownFrame({:#x}, {:#x})", ty, flags) + return write!(f, "UnknownFrame({:#x}, {:#x}, len={})", ty, flags, self.len) } }; let mut s = f.debug_struct(name); @@ -376,7 +376,7 @@ impl Frame { } /// Returns true if this frame is an ack - pub fn is_ack(self) -> bool { + pub fn is_ack(&self) -> bool { match self.frame_type { FrameType::Settings(flags) => flags.contains(SettingsFlags::Ack), FrameType::Ping(flags) => flags.contains(PingFlags::Ack), @@ -385,7 +385,7 @@ impl Frame { } /// Returns true if this frame has `EndHeaders` set - pub fn is_end_headers(self) -> bool { + pub fn is_end_headers(&self) -> bool { match self.frame_type { FrameType::Headers(flags) => flags.contains(HeadersFlags::EndHeaders), FrameType::Continuation(flags) => flags.contains(ContinuationFlags::EndHeaders), @@ -394,7 +394,7 @@ impl Frame { } /// Returns true if this frame has `EndStream` set - pub fn is_end_stream(self) -> bool { + pub fn is_end_stream(&self) -> bool { match self.frame_type { FrameType::Data(flags) => flags.contains(DataFlags::EndStream), FrameType::Headers(flags) => flags.contains(HeadersFlags::EndStream), diff --git a/crates/fluke-httpwg-server/Cargo.toml b/crates/fluke-httpwg-server/Cargo.toml new file mode 100644 index 00000000..3489569b --- /dev/null +++ b/crates/fluke-httpwg-server/Cargo.toml @@ -0,0 +1,11 @@ +[package] +name = "fluke-httpwg-server" +version = "0.1.0" +edition = "2021" + +[dependencies] +color-eyre = "0.6.3" +fluke = { version = "0.1.1", path = "../fluke" } +fluke-buffet = { version = "0.2.0", path = "../fluke-buffet" } +tracing = "0.1.40" +tracing-subscriber = "0.3.18" diff --git a/crates/fluke-httpwg-server/src/main.rs b/crates/fluke-httpwg-server/src/main.rs new file mode 100644 index 00000000..d95d8d70 --- /dev/null +++ b/crates/fluke-httpwg-server/src/main.rs @@ -0,0 +1,128 @@ +use std::rc::Rc; + +use color_eyre::eyre; +use fluke::{ + http::{self, StatusCode}, + Body, BodyChunk, Encoder, ExpectResponseHeaders, Responder, Response, ResponseDone, +}; +use fluke_buffet::{IntoHalves, RollMut}; +use tracing::Level; +use tracing_subscriber::{filter::Targets, layer::SubscriberExt, util::SubscriberInitExt}; + +fn main() { + setup_tracing_and_error_reporting(); + + fluke_buffet::start(async move { + let ln = fluke_buffet::net::TcpListener::bind("127.0.0.1:8000".parse().unwrap()) + .await + .unwrap(); + + println!( + "Listening on {:?} for 'http2 prior knowledge' connections (no TLS)", + ln.local_addr().unwrap() + ); + + loop { + let (stream, addr) = ln.accept().await.unwrap(); + tracing::info!(?addr, "Accepted connection"); + + fluke_buffet::spawn(async move { + let server_conf = Rc::new(fluke::h2::ServerConf { + ..Default::default() + }); + + let client_buf = RollMut::alloc().unwrap(); + let driver = Rc::new(TestDriver); + let io = stream.into_halves(); + fluke::h2::serve(io, server_conf, client_buf, driver) + .await + .unwrap(); + tracing::debug!("http/2 server done"); + }); + } + }); +} + +fn setup_tracing_and_error_reporting() { + color_eyre::install().unwrap(); + + let targets = if let Ok(rust_log) = std::env::var("RUST_LOG") { + rust_log.parse::().unwrap() + } else { + Targets::new() + .with_default(Level::INFO) + .with_target("fluke", Level::DEBUG) + .with_target("httpwg", Level::DEBUG) + .with_target("want", Level::INFO) + }; + + let fmt_layer = tracing_subscriber::fmt::layer() + .with_ansi(true) + .with_file(false) + .with_line_number(false) + .without_time(); + + tracing_subscriber::registry() + .with(targets) + .with(fmt_layer) + .init(); +} + +struct TestDriver; + +impl fluke::ServerDriver for TestDriver { + async fn handle( + &self, + _req: fluke::Request, + req_body: &mut impl Body, + mut res: Responder, + ) -> eyre::Result> { + // if the client sent `expect: 100-continue`, we must send a 100 status code + if let Some(h) = _req.headers.get(http::header::EXPECT) { + if &h[..] == b"100-continue" { + res.write_interim_response(Response { + status: StatusCode::CONTINUE, + ..Default::default() + }) + .await?; + } + } + + // then read the full request body + let mut req_body_len = 0; + loop { + let chunk = req_body.next_chunk().await?; + match chunk { + BodyChunk::Done { trailers } => { + // yey + if let Some(trailers) = trailers { + tracing::debug!(trailers_len = %trailers.len(), "received trailers"); + } + break; + } + BodyChunk::Chunk(chunk) => { + req_body_len += chunk.len(); + } + } + } + tracing::debug!(%req_body_len, "read request body"); + + tracing::trace!("writing final response"); + let mut res = res + .write_final_response(Response { + status: StatusCode::OK, + ..Default::default() + }) + .await?; + + tracing::trace!("writing body chunk"); + res.write_chunk("it's less dire to lose, than to lose oneself".into()) + .await?; + + tracing::trace!("finishing body (with no trailers)"); + let res = res.finish_body(None).await?; + + tracing::trace!("we're done"); + Ok(res) + } +} diff --git a/crates/fluke/src/h2/body.rs b/crates/fluke/src/h2/body.rs index 422e4cb1..676cea2b 100644 --- a/crates/fluke/src/h2/body.rs +++ b/crates/fluke/src/h2/body.rs @@ -5,29 +5,125 @@ use tokio::sync::mpsc; use crate::{Body, BodyChunk, Headers}; use fluke_buffet::Piece; -pub(crate) enum PieceOrTrailers { +use super::types::H2StreamError; + +/// Something we receive from an http/2 peer: pieces of the request +/// body, the final trailers, or perhaps an error! if the client doesn't +/// end up sending exactly the number of bytes they promised. +pub(crate) enum IncomingMessage { Piece(Piece), Trailers(Box), } +pub(crate) enum ChunkPosition { + NotLast, + Last, +} + pub(crate) struct StreamIncoming { - // TODO: don't allow access to tx, check against capacity first? - pub(crate) tx: mpsc::Sender, + tx: mpsc::Sender, + + // total bytes received, which we keep track of, because if the client + // announces a content-length and sends fewer or more bytes, we will + // error out. + pub(crate) total_received: u64, + pub(crate) content_length: Option, // incoming capacity (that we decide, we get to tell // the peer how much we can handle with window updates) pub(crate) capacity: i64, } +impl StreamIncoming { + pub(crate) fn new( + initial_window_size: u32, + content_length: Option, + piece_tx: mpsc::Sender>, + ) -> Self { + Self { + tx: piece_tx, + total_received: 0, + content_length, + capacity: initial_window_size as i64, + } + } + + pub(crate) async fn write_chunk( + &mut self, + chunk: Piece, + which: ChunkPosition, + ) -> Result<(), H2StreamError> { + match self.total_received.checked_add(chunk.len() as u64) { + Some(new_total) => { + self.total_received = new_total; + } + None => return Err(H2StreamError::OverflowWhileCalculatingContentLength), + } + + if let Some(content_length) = self.content_length { + if self.total_received > content_length { + return Err(H2StreamError::DataLengthDoesNotMatchContentLength { + data_length: self.total_received, + content_length, + }); + } + + if matches!(which, ChunkPosition::Last) && self.total_received != content_length { + return Err(H2StreamError::DataLengthDoesNotMatchContentLength { + data_length: self.total_received, + content_length, + }); + } + } + + if self + .tx + .send(Ok(IncomingMessage::Piece(chunk))) + .await + .is_err() + { + // the stream is being ignored, so let's reset it + return Err(H2StreamError::Cancel); + } + Ok(()) + } + + pub(crate) async fn write_trailers(&mut self, trailers: Headers) -> Result<(), H2StreamError> { + if let Some(content_length) = self.content_length { + if self.total_received != content_length { + return Err(H2StreamError::DataLengthDoesNotMatchContentLength { + data_length: self.total_received, + content_length, + }); + } + } + + let _ = self + .tx + .send(Ok(IncomingMessage::Trailers(Box::new(trailers)))) + .await; + + // TODO: keep track of what we've sent, panic if we're not in the right state. + + Ok(()) + } + + pub(crate) async fn send_error(&mut self, err: eyre::Report) { + let _ = self.tx.send(Err(err)).await; + } +} + // FIXME: don't use eyre, do proper error handling -pub(crate) type StreamIncomingItem = eyre::Result; +pub(crate) type IncomingMessagesResult = eyre::Result; #[derive(Debug)] pub(crate) struct H2Body { pub(crate) content_length: Option, + pub(crate) eof: bool, + // TODO: more specific error handling - pub(crate) rx: mpsc::Receiver, + pub(crate) rx: mpsc::Receiver, } impl Body for H2Body { @@ -44,9 +140,9 @@ impl Body for H2Body { BodyChunk::Done { trailers: None } } else { match self.rx.recv().await { - Some(maybe_piece_or_trailers) => match maybe_piece_or_trailers? { - PieceOrTrailers::Piece(piece) => BodyChunk::Chunk(piece), - PieceOrTrailers::Trailers(trailers) => { + Some(msg) => match msg? { + IncomingMessage::Piece(piece) => BodyChunk::Chunk(piece), + IncomingMessage::Trailers(trailers) => { self.eof = true; BodyChunk::Done { trailers: Some(trailers), diff --git a/crates/fluke/src/h2/server.rs b/crates/fluke/src/h2/server.rs index 579711e5..7d4f0101 100644 --- a/crates/fluke/src/h2/server.rs +++ b/crates/fluke/src/h2/server.rs @@ -1,6 +1,6 @@ use std::{ borrow::Cow, - collections::HashSet, + collections::{hash_map::Entry, HashSet}, io::Write, rc::Rc, sync::atomic::{AtomicU32, Ordering}, @@ -26,7 +26,7 @@ use tracing::{debug, trace}; use crate::{ h2::{ - body::{H2Body, PieceOrTrailers, StreamIncoming, StreamIncomingItem}, + body::{H2Body, IncomingMessagesResult, StreamIncoming}, encode::H2Encoder, types::{ BodyOutgoing, ConnState, H2ConnectionError, H2Event, H2EventPayload, H2RequestError, @@ -37,7 +37,10 @@ use crate::{ Headers, Method, Request, Responder, ServerDriver, }; -use super::{body::SinglePieceBody, types::H2RequestOrConnectionError}; +use super::{ + body::{ChunkPosition, SinglePieceBody}, + types::H2ErrorLevel, +}; pub const MAX_WINDOW_SIZE: i64 = u32::MAX as i64; @@ -65,7 +68,6 @@ pub async fn serve( let mut cx = ServerContext::new(driver.clone(), state, transport_w)?; cx.work(client_buf, transport_r).await?; - cx.transport_w.shutdown().await?; debug!("finished serving"); Ok(()) @@ -582,6 +584,8 @@ impl ServerContext { // tell the sender to stop sending chunks, which is not // possible if they all share the same ev_tx // TODO: make it possible to propagate errors to the sender + tracing::warn!(stream_id = %ev.stream_id, "ignoring event for stream, since we have no state for it"); + return Ok(()); } Some(outgoing) => outgoing, @@ -828,18 +832,17 @@ impl ServerContext { }); } incoming.capacity = next_cap; - // TODO: give back capacity to peer at some point - if incoming - .tx - .send(Ok(PieceOrTrailers::Piece(payload.into()))) - .await - .is_err() - { - debug!("TODO: The body is being ignored, we should reset the stream"); - } + let which = if frame.is_end_stream() { + ChunkPosition::Last + } else { + ChunkPosition::NotLast + }; - if flags.contains(DataFlags::EndStream) { + // TODO: give back capacity to peer at some point + if let Err(e) = incoming.write_chunk(payload.into(), which).await { + self.rst(frame.stream_id, e).await?; + } else if flags.contains(DataFlags::EndStream) { if let StreamState::Open { .. } = ss { let outgoing = match std::mem::take(ss) { StreamState::Open { outgoing, .. } => outgoing, @@ -976,9 +979,13 @@ impl ServerContext { .await { match e { - H2RequestOrConnectionError::ConnectionError(e) => return Err(e), - H2RequestOrConnectionError::RequestError(e) => { + H2ErrorLevel::Connection(e) => return Err(e), + H2ErrorLevel::Stream(e) => { + self.rst(frame.stream_id, e).await?; + } + H2ErrorLevel::Request(e) => { let stream_id = frame.stream_id; + tracing::debug!(?e, %stream_id, "Responding to stream with error"); // we need to insert it, otherwise `process_event` will ignore us // sending headers, etc. @@ -1063,11 +1070,10 @@ impl ServerContext { self.state.streams.len() ); match ss { - StreamState::Open { incoming, .. } - | StreamState::HalfClosedLocal { incoming, .. } => { - _ = incoming - .tx - .send(Err(H2StreamError::ReceivedRstStream.into())) + StreamState::Open { mut incoming, .. } + | StreamState::HalfClosedLocal { mut incoming, .. } => { + incoming + .send_error(eyre::eyre!("Received RST_STREAM from peer")) .await; } StreamState::HalfClosedRemote { .. } => { @@ -1311,7 +1317,7 @@ impl ServerContext { stream_id: StreamId, payload: Roll, rx: &mut mpsc::Receiver<(Frame, Roll)>, - ) -> Result<(), H2RequestOrConnectionError> { + ) -> Result<(), H2ErrorLevel> { let end_stream = flags.contains(HeadersFlags::EndStream); enum Data { @@ -1394,7 +1400,7 @@ impl ServerContext { // matter what: if we receive invalid headers for one request, we should still // keep reading the next request's headers, and that requires advancing the // huffman decoder's state, etc. - let mut req_error: Option = None; + let mut req_error: Option = None; let mut saw_regular_header = false; let on_header_pair = |key: Cow<[u8]>, value: Cow<[u8]>| { @@ -1413,20 +1419,16 @@ impl ServerContext { if &key[..1] == b":" { if saw_regular_header { - req_error = Some(H2RequestError { - status: StatusCode::BAD_REQUEST, - message: - "bad request: All pseudo-header fields MUST appear in a field block before all regular field lines (RFC 9113, section 8.3)" - .into(), - }); + req_error = Some(H2StreamError::BadRequest( + "All pseudo-header fields MUST appear in a field block before all regular field lines (RFC 9113, section 8.3)" + )); return; } if matches!(headers_or_trailers, HeadersOrTrailers::Trailers) { - req_error = Some(H2RequestError { - status: StatusCode::BAD_REQUEST, - message: "bad request: Pseudo-header fields MUST NOT appear in a trailer section (RFC 9113, section 8.3)".into(), - }); + req_error = Some(H2StreamError::BadRequest( + "Pseudo-header fields MUST NOT appear in a trailer section (RFC 9113, section 8.3)" + )); return; } @@ -1437,96 +1439,68 @@ impl ServerContext { let value: PieceStr = match Piece::from(value.to_vec()).to_str() { Ok(p) => p, Err(_) => { - req_error = Some(H2RequestError { - status: StatusCode::BAD_REQUEST, - message: "bad request: invalid ':method' pseudo-header: not valid utf-8, so certainly not a valid method like POST, GET, OPTIONS, CONNECT, PROPFIND, etc.".into(), - }); + req_error = Some(H2StreamError::BadRequest( + "invalid ':method' pseudo-header: not valid utf-8, so certainly not a valid method like POST, GET, OPTIONS, CONNECT, PROPFIND, etc.", + )); return; } }; if method.replace(Method::from(value)).is_some() { - req_error = Some(H2RequestError { - status: StatusCode::BAD_REQUEST, - message: "bad request: duplicate ':method' pseudo-header. All HTTP/2 requests MUST include _exactly one_ valid value for the ':method', ':scheme', and ':path' pseudo-header fields, unless they are CONNECT requests (RFC 9114, section 8.3.1)".into(), - }); + req_error = Some(H2StreamError::BadRequest("duplicate ':method' pseudo-header. All HTTP/2 requests MUST include _exactly one_ valid value for the ':method', ':scheme', and ':path' pseudo-header fields, unless they are CONNECT requests (RFC 9114, section 8.3.1)")); } } b"scheme" => { let value: PieceStr = match Piece::from(value.to_vec()).to_str() { Ok(p) => p, Err(_) => { - req_error = Some(H2RequestError { - status: StatusCode::BAD_REQUEST, - message: "bad request: invalid ':scheme' pseudo-header: not valid utf-8".into(), - }); + req_error = Some(H2StreamError::BadRequest( + "invalid ':scheme' pseudo-header: not valid utf-8", + )); return; } }; if scheme.replace(value.parse().unwrap()).is_some() { - req_error = Some(H2RequestError { - status: StatusCode::BAD_REQUEST, - message: "bad request: duplicate ':scheme' pseudo-header. All HTTP/2 requests MUST include _exactly one_ valid value for the ':method', ':scheme', and ':path' pseudo-header fields, unless they are CONNECT requests (RFC 9113, section 8.3.1)" - .into(), - }); + req_error = Some(H2StreamError::BadRequest("duplicate ':scheme' pseudo-header. All HTTP/2 requests MUST include _exactly one_ valid value for the ':method', ':scheme', and ':path' pseudo-header fields, unless they are CONNECT requests (RFC 9113, section 8.3.1)")); } } b"path" => { let value: PieceStr = match Piece::from(value.to_vec()).to_str() { Ok(val) => val, Err(_) => { - req_error = Some(H2RequestError { - status: StatusCode::BAD_REQUEST, - message: - "bad request: invalid ':path' pseudo-header (not valid utf-8, which is _certainly_ not a valid URI, as defined by RFC 3986, section 2. See also RFC 9113, section 8.3.1). " - .into(), - }); + req_error = Some(H2StreamError::BadRequest("invalid ':path' pseudo-header (not valid utf-8, which is _certainly_ not a valid URI, as defined by RFC 3986, section 2. See also RFC 9113, section 8.3.1). ")); return; } }; if path.replace(value).is_some() { - req_error = Some(H2RequestError { - status: StatusCode::BAD_REQUEST, - message: "bad request: duplicate ':path' pseudo-header. All HTTP/2 requests MUST include _exactly one_ valid value for the ':method', ':scheme', and ':path' pseudo-header fields, unless they are CONNECT requests (RFC 9113, section 8.3.1)".into(), - }); + req_error = Some(H2StreamError::BadRequest("duplicate ':path' pseudo-header. All HTTP/2 requests MUST include _exactly one_ valid value for the ':method', ':scheme', and ':path' pseudo-header fields, unless they are CONNECT requests (RFC 9113, section 8.3.1)")); } } b"authority" => { let value: PieceStr = match Piece::from(value.to_vec()).to_str() { Ok(p) => p, Err(_) => { - req_error = Some(H2RequestError { - status: StatusCode::BAD_REQUEST, - message: "bad request: invalid ':authority' pseudo-header: not valid utf-8".into(), - }); + req_error = Some(H2StreamError::BadRequest( + "invalid ':authority' pseudo-header: not valid utf-8", + )); return; } }; let value: Authority = match value.parse() { Ok(a) => a, Err(_) => { - req_error = Some(H2RequestError { - status: StatusCode::BAD_REQUEST, - message: "bad request: invalid ':authority' pseudo-header: not a valid authority (which is to say: not a valid URI, see RFC 3986, section 3.2)".into(), - }); + req_error = Some(H2StreamError::BadRequest("invalid ':authority' pseudo-header: not a valid authority (which is to say: not a valid URI, see RFC 3986, section 3.2)")); return; } }; if authority.replace(value).is_some() { - req_error = Some(H2RequestError { - status: StatusCode::BAD_REQUEST, - message: "bad request: duplicate ':authority' pseudo-header. This one isn't technically forbidden by RFC 9113 section 8, but... it feels like it should be." - .into(), - }); + req_error = Some(H2StreamError::BadRequest("duplicate ':authority' pseudo-header. All HTTP/2 requests MUST include _exactly one_ valid value for the ':method', ':scheme', and ':path' pseudo-header fields, unless they are CONNECT requests (RFC 9113, section 8.3.1)")); } } _ => { - req_error = Some(H2RequestError { - status: StatusCode::BAD_REQUEST, - message: - "bad request: received invalid pseudo-header. the only defined pseudo-headers are: ':method', ':scheme', ':path', ':authority', ':status' (RFC 9113, section 8.1)" - .into(), - }); + req_error = Some(H2StreamError::BadRequest( + "received invalid pseudo-header. the only defined pseudo-headers are: ':method', ':scheme', ':path', ':authority', ':status' (RFC 9113, section 8.1)", + )); } } } else { @@ -1535,10 +1509,9 @@ impl ServerContext { let name = match HeaderName::from_bytes(&key[..]) { Ok(name) => name, Err(_) => { - req_error = Some(H2RequestError { - status: StatusCode::BAD_REQUEST, - message: "bad request: invalid header name. see RFC 9113, section 8.2.1, 'Field validity'".into(), - }); + req_error = Some(H2StreamError::BadRequest( + "invalid header name. see RFC 9113, section 8.2.1, 'Field validity'", + )); return; } }; @@ -1547,10 +1520,9 @@ impl ServerContext { // Sections 5.1 and 5.5 of [HTTP] only needs an additional check that field // names do not include uppercase characters. if key.iter().any(|b: &u8| b.is_ascii_uppercase()) { - req_error = Some(H2RequestError { - status: StatusCode::BAD_REQUEST, - message: "bad request: A field name MUST NOT contain characters in the ranges 0x00-0x20, 0x41-0x5a, or 0x7f-0xff (all ranges inclusive). This specifically excludes all non-visible ASCII characters, ASCII SP (0x20), and uppercase characters ('A' to 'Z', ASCII 0x41 to 0x5a). See RFC9113, section 8.2.1, 'Field Validity'".into(), - }); + req_error = Some(H2StreamError::BadRequest( + "A field name MUST NOT contain characters in the ranges 0x00-0x20, 0x41-0x5a, or 0x7f-0xff (all ranges inclusive). This specifically excludes all non-visible ASCII characters, ASCII SP (0x20), and uppercase characters ('A' to 'Z', ASCII 0x41 to 0x5a). See RFC9113, section 8.2.1, 'Field Validity'", + )); return; } @@ -1565,18 +1537,16 @@ impl ServerContext { || name == http::header::TRANSFER_ENCODING || name == http::header::UPGRADE { - req_error = Some(H2RequestError { - status: StatusCode::BAD_REQUEST, - message: "bad request: connection-specific headers are forbidden. see RFC 9113, section 8.1.2".into(), - }); + req_error = Some(H2StreamError::BadRequest( + "connection-specific headers are forbidden. see RFC 9113, section 8.1.2", + )); return; } if name == http::header::TE && &value[..] != b"trailers" { - req_error = Some(H2RequestError { - status: StatusCode::BAD_REQUEST, - message: "bad request: 'te' did not contain 'trailers'. cf. RFC9113, Section 8.2.2: The only exception to this is the TE header field, which MAY be present in an HTTP/2 request; when it is, it MUST NOT contain any value other than 'trailers'".into(), - }); + req_error = Some(H2StreamError::BadRequest( + "The only exception to this is the TE header field, which MAY be present in an HTTP/2 request; when it is, it MUST NOT contain any value other than 'trailers'. cf. RFC9113, Section 8.2.2", + )); return; } @@ -1588,10 +1558,9 @@ impl ServerContext { || last == Some(&b' ') || last == Some(&b'\x09') { - req_error = Some(H2RequestError { - status: StatusCode::BAD_REQUEST, - message: "bad request: A field value MUST NOT start or end with an ASCII whitespace character (ASCII SP or HTAB, 0x20 or 0x09). (RFC 9113, section 8.2.1, 'Field validity')".into(), - }); + req_error = Some(H2StreamError::BadRequest( + "A field value MUST NOT start or end with an ASCII whitespace character (ASCII SP or HTAB, 0x20 or 0x09). (RFC 9113, section 8.2.1, 'Field validity')", + )); return; } @@ -1599,10 +1568,9 @@ impl ServerContext { .iter() .any(|&b| b == b'\r' || b == b'\n' || b == b'\0') { - req_error = Some(H2RequestError { - status: StatusCode::BAD_REQUEST, - message: "bad request: A field value MUST NOT contain the zero value (ASCII NUL, 0x00), line feed (ASCII LF, 0x0a), or carriage return (ASCII CR, 0x0d) at any position. See RFC 9113, section 8.2.1, 'Field validity'".into(), - }); + req_error = Some(H2StreamError::BadRequest( + "A field value MUST NOT contain the zero value (ASCII NUL, 0x00), line feed (ASCII LF, 0x0a), or carriage return (ASCII CR, 0x0d) at any position. See RFC 9113, section 8.2.1, 'Field validity'", + )); return; } @@ -1615,7 +1583,7 @@ impl ServerContext { Data::Single(payload) => { self.hpack_dec .decode_with_cb(&payload[..], on_header_pair) - .map_err(|e| H2RequestOrConnectionError::ConnectionError(e.into()))?; + .map_err(|e| H2ErrorLevel::Connection(e.into()))?; } Data::Multi(fragments) => { let total_len = fragments.iter().map(|f| f.len()).sum(); @@ -1628,7 +1596,7 @@ impl ServerContext { } self.hpack_dec .decode_with_cb(&payload[..], on_header_pair) - .map_err(|e| H2RequestOrConnectionError::ConnectionError(e.into()))?; + .map_err(|e| H2ErrorLevel::Connection(e.into()))?; } }; @@ -1650,31 +1618,28 @@ impl ServerContext { // RFC 9113, section 8.5 'The CONNECT method': The ":scheme" and ":path" // pseudo-header fields MUST be omitted. if scheme.is_some() { - return Err(H2RequestError { - status: StatusCode::BAD_REQUEST, - message: "bad request: CONNECT method MUST NOT include ':scheme' pseudo-header".into(), - } + return Err(H2StreamError::BadRequest( + "CONNECT method MUST NOT include ':scheme' pseudo-header", + ) .into()); } if path.is_some() { - return Err(H2RequestError { - status: StatusCode::BAD_REQUEST, - message: "bad request: CONNECT method MUST NOT include ':path' pseudo-header".into(), - } + return Err(H2StreamError::BadRequest( + "CONNECT method MUST NOT include ':path' pseudo-header", + ) .into()); } if authority.is_none() { - return Err(H2RequestError { - status: StatusCode::BAD_REQUEST, - message: "bad request: CONNECT method MUST include ':authority' pseudo-header".into(), - } + return Err(H2StreamError::BadRequest( + "CONNECT method MUST include ':authority' pseudo-header", + ) .into()); } // well, also, we just don't support the `CONNECT` method. return Err(H2RequestError { status: StatusCode::NOT_IMPLEMENTED, - message: "bad request: CONNECT method is not supported".into(), + message: "CONNECT method is not supported".into(), } .into()); } @@ -1682,55 +1647,43 @@ impl ServerContext { method } None => { - return Err(H2RequestError { - status: StatusCode::BAD_REQUEST, - message: "bad request: missing :method pseudo-header".into(), - } - .into()) + return Err( + H2StreamError::BadRequest("missing :method pseudo-header").into() + ) } }; let scheme = match scheme { Some(scheme) => scheme, None => { - return Err(H2RequestError { - status: StatusCode::BAD_REQUEST, - message: "bad request: missing :scheme pseudo-header".into(), - } - .into()) + return Err( + H2StreamError::BadRequest("missing :scheme pseudo-header").into() + ); } }; - let path = - match path { - Some(path) => path, - None => return Err(H2RequestError { - status: StatusCode::BAD_REQUEST, - message: - "bad request: missing :path pseudo-header, cf. RFC9113, section 8.3.1: This pseudo-header field MUST NOT be empty for 'http' or 'https' URIs; 'http' or 'https' URIs that do not contain a path component MUST include a value of '/'." - .into(), - } - .into()), - }; + let path = match path { + Some(path) => path, + None => { + return Err( + H2StreamError::BadRequest("missing :path pseudo-header, cf. RFC9113, section 8.3.1: This pseudo-header field MUST NOT be empty for 'http' or 'https' URIs; 'http' or 'https' URIs that do not contain a path component MUST include a value of '/'.").into() + ); + } + }; if path.len() == 0 && (scheme == Scheme::HTTP || scheme == Scheme::HTTPS) { - return Err(H2RequestError { - status: StatusCode::BAD_REQUEST, - message: "bad request: as per RFC9113, section 8.3.1, ':path' header value MUST NOT be empty for 'http' and 'https' URIs".into(), - } - .into()); + return Err(H2StreamError::BadRequest( + "as per RFC9113, section 8.3.1, ':path' header value MUST NOT be empty for 'http' and 'https' URIs", + ).into()); } let path_and_query: PathAndQuery = match path.parse() { Ok(p) => p, Err(_) => { - return Err(H2RequestError { - status: StatusCode::BAD_REQUEST, - message: - "bad request: ':path' header value is not a valid PathAndQuery" - .into(), - } - .into()) + return Err(H2StreamError::BadRequest( + "':path' header value is not a valid PathAndQuery", + ) + .into()); } }; @@ -1738,22 +1691,16 @@ impl ServerContext { Some(authority) => { // if there's a `host` header, it must match the `:authority` pseudo-header if let Some(host) = headers.get(header::HOST) { - let host = std::str::from_utf8(host).map_err(|_| H2RequestError { - status: StatusCode::BAD_REQUEST, - message: "bad request: 'host' header value is not utf-8".into(), + let host = std::str::from_utf8(host).map_err(|_| { + H2StreamError::BadRequest("'host' header value is not utf-8") + })?; + let host_authority: Authority = host.parse().map_err(|_| { + H2StreamError::BadRequest("'host' header value is not a valid URI") })?; - let host_authority: Authority = - host.parse().map_err(|_| H2RequestError { - status: StatusCode::BAD_REQUEST, - message: "bad request: 'host' header value is not a valid URI" - .into(), - })?; if host_authority != authority { - return Err(H2RequestError { - status: StatusCode::BAD_REQUEST, - message: "bad request: 'host' header value does not match ':authority' pseudo-header value, cf. RFC9113, Section 8.3.1: A server SHOULD treat a request as malformed if it contains a Host header field that identifies an entity that differs from the entity in the ':authority' pseudo-header field".into(), - } - .into()); + return Err(H2StreamError::BadRequest( + "'host' header value does not match ':authority' pseudo-header value, cf. RFC9113, Section 8.3.1: A server SHOULD treat a request as malformed if it contains a Host header field that identifies an entity that differs from the entity in the ':authority' pseudo-header field" + ).into()); } } @@ -1761,16 +1708,12 @@ impl ServerContext { } None => match headers.get(header::HOST) { Some(host) => { - let host = std::str::from_utf8(host).map_err(|_| H2RequestError { - status: StatusCode::BAD_REQUEST, - message: "bad request: 'host' header value is not utf-8".into(), + let host = std::str::from_utf8(host).map_err(|_| { + H2StreamError::BadRequest("'host' header value is not utf-8") + })?; + let authority: Authority = host.parse().map_err(|_| { + H2StreamError::BadRequest("'host' header value is not a valid URI") })?; - let authority: Authority = - host.parse().map_err(|_| H2RequestError { - status: StatusCode::BAD_REQUEST, - message: "bad request: 'host' header value is not a valid URI" - .into(), - })?; Some(authority) } None => None, @@ -1787,7 +1730,7 @@ impl ServerContext { Err(_) => { return Err(H2RequestError { status: StatusCode::BAD_REQUEST, - message: "bad request: invalid URI parts".into(), + message: "invalid URI parts".into(), } .into()) } @@ -1799,23 +1742,45 @@ impl ServerContext { version: Version::HTTP_2, headers, }; + let content_length: Option = match req + .headers + .get(http::header::CONTENT_LENGTH) + { + Some(len) => { + let len = std::str::from_utf8(len).map_err(|_| { + H2StreamError::BadRequest("content-length header value is not utf-8") + })?; + let len = len.parse().map_err(|_| { + H2StreamError::BadRequest( + "content-length header value is not a valid integer", + ) + })?; + Some(len) + } + None => { + if end_stream { + Some(0) + } else { + None + } + } + }; let responder = Responder::new(H2Encoder::new(stream_id, self.ev_tx.clone())); - let (piece_tx, piece_rx) = mpsc::channel::(1); // TODO: is 1 a sensible value here? + let (piece_tx, piece_rx) = mpsc::channel::(1); // TODO: is 1 a sensible value here? let req_body = H2Body { - // FIXME: that's not right. h2 requests can still specify - // a content-length - content_length: if end_stream { Some(0) } else { None }, + content_length, eof: end_stream, rx: piece_rx, }; - let incoming = StreamIncoming { - capacity: self.state.self_settings.initial_window_size as _, - tx: piece_tx, - }; + let incoming = StreamIncoming::new( + self.state.self_settings.initial_window_size as _, + content_length, + piece_tx, + ); let outgoing: StreamOutgoing = self.state.mk_stream_outgoing(); self.state.streams.insert( stream_id, @@ -1855,24 +1820,31 @@ impl ServerContext { }); } HeadersOrTrailers::Trailers => { - match self.state.streams.get_mut(&stream_id) { - Some(StreamState::Open { incoming, .. }) => { - if incoming - .tx - .send(Ok(PieceOrTrailers::Trailers(Box::new(headers)))) - .await - .is_err() - { - // the body is being ignored, but there's no point - // in resetting the - // stream since we just got the end of it + match self.state.streams.entry(stream_id) { + Entry::Occupied(mut slot) => match slot.get_mut() { + StreamState::Open { incoming, .. } => { + incoming.write_trailers(headers).await?; + + // set stream state to half closed remote. we do a little + // dance to avoid re-inserting. + let hcr = slot.insert(StreamState::Transition); + slot.insert(match hcr { + StreamState::Open { outgoing, .. } => { + StreamState::HalfClosedRemote { outgoing } + } + _ => unreachable!(), + }); } - } - _ => { + _ => { + unreachable!("stream state should be open when we receive trailers") + } + }, + Entry::Vacant(_) => { + // we received trailers for a stream that doesn't exist + // anymore, ignore them unreachable!("stream state should be open when we receive trailers") } } - self.state.streams.remove(&stream_id); } } diff --git a/crates/fluke/src/h2/types.rs b/crates/fluke/src/h2/types.rs index 23e2794d..a359bff7 100644 --- a/crates/fluke/src/h2/types.rs +++ b/crates/fluke/src/h2/types.rs @@ -279,12 +279,15 @@ impl BodyOutgoing { /// An error that may either indicate the peer is misbehaving /// or just a bad request from the client. #[derive(Debug, thiserror::Error)] -pub(crate) enum H2RequestOrConnectionError { +pub(crate) enum H2ErrorLevel { #[error("connection error: {0}")] - ConnectionError(#[from] H2ConnectionError), + Connection(#[from] H2ConnectionError), + + #[error("stream error: {0}")] + Stream(#[from] H2StreamError), #[error("request error: {0}")] - RequestError(#[from] H2RequestError), + Request(#[from] H2RequestError), } /// The client done goofed, we're returning 4xx most likely @@ -456,22 +459,21 @@ impl H2ConnectionError { #[derive(Debug, thiserror::Error)] pub(crate) enum H2StreamError { - #[allow(dead_code)] #[error("received {data_length} bytes in data frames but content-length announced {content_length} bytes")] DataLengthDoesNotMatchContentLength { data_length: u64, content_length: u64, }, + #[error("overflow while calculating content length")] + OverflowWhileCalculatingContentLength, + #[error("refused stream (would exceed max concurrent streams)")] RefusedStream, #[error("trailers must have EndStream flag set")] TrailersNotEndStream, - #[error("received RST_STREAM frame")] - ReceivedRstStream, - #[error("received PRIORITY frame with invalid size")] InvalidPriorityFrameSize { frame_size: u32 }, @@ -483,6 +485,12 @@ pub(crate) enum H2StreamError { #[error("received WINDOW_UPDATE that made the window size overflow")] WindowUpdateOverflow, + + #[error("bad request: {0}")] + BadRequest(&'static str), + + #[error("stream reset")] + Cancel, } impl H2StreamError { @@ -491,6 +499,7 @@ impl H2StreamError { use KnownErrorCode as Code; match self { + Cancel => Code::Cancel, // stream closed error StreamClosed => Code::StreamClosed, // stream refused error diff --git a/crates/fluke/tests/httpwg2.rs b/crates/fluke/tests/httpwg2.rs new file mode 100644 index 00000000..18bd0cc5 --- /dev/null +++ b/crates/fluke/tests/httpwg2.rs @@ -0,0 +1,143 @@ +// TODO: remove me, this was only used for a video + +use std::rc::Rc; + +use fluke::{Body, BodyChunk, Encoder, ExpectResponseHeaders, Responder, Response, ResponseDone}; +use fluke_buffet::{IntoHalves, PipeRead, PipeWrite, ReadOwned, RollMut, WriteOwned}; +use http::StatusCode; +use tracing::Level; +use tracing_subscriber::{filter::Targets, layer::SubscriberExt, util::SubscriberInitExt}; + +/// Note: this will not work with `cargo test`, since it sets up process-level +/// globals. But it will work with `cargo nextest`, and that's what fluke is +/// standardizing on. +pub(crate) fn setup_tracing_and_error_reporting() { + color_eyre::install().unwrap(); + + let targets = if let Ok(rust_log) = std::env::var("RUST_LOG") { + rust_log.parse::().unwrap() + } else { + Targets::new() + .with_default(Level::INFO) + .with_target("fluke", Level::DEBUG) + .with_target("httpwg", Level::DEBUG) + .with_target("want", Level::INFO) + }; + + let fmt_layer = tracing_subscriber::fmt::layer() + .with_ansi(true) + .with_file(false) + .with_line_number(false) + .without_time(); + + tracing_subscriber::registry() + .with(targets) + .with(fmt_layer) + .init(); +} + +struct TestDriver; + +impl fluke::ServerDriver for TestDriver { + async fn handle( + &self, + _req: fluke::Request, + req_body: &mut impl Body, + mut res: Responder, + ) -> eyre::Result> { + // if the client sent `expect: 100-continue`, we must send a 100 status code + if let Some(h) = _req.headers.get(http::header::EXPECT) { + if &h[..] == b"100-continue" { + res.write_interim_response(Response { + status: StatusCode::CONTINUE, + ..Default::default() + }) + .await?; + } + } + + // then read the full request body + let mut req_body_len = 0; + loop { + let chunk = req_body.next_chunk().await?; + match chunk { + BodyChunk::Done { trailers } => { + // yey + if let Some(trailers) = trailers { + tracing::debug!(trailers_len = %trailers.len(), "received trailers"); + } + break; + } + BodyChunk::Chunk(chunk) => { + req_body_len += chunk.len(); + } + } + } + tracing::debug!(%req_body_len, "read request body"); + + let mut res = res + .write_final_response(Response { + status: StatusCode::OK, + ..Default::default() + }) + .await?; + + res.write_chunk("it's less dire to lose, than to lose oneself".into()) + .await?; + + let res = res.finish_body(None).await?; + + Ok(res) + } +} + +pub struct TwoHalves(W, R); +impl IntoHalves for TwoHalves { + type Read = R; + type Write = W; + + fn into_halves(self) -> (Self::Read, Self::Write) { + (self.1, self.0) + } +} + +pub fn start_server() -> httpwg::Conn> { + let (server_write, client_read) = fluke::buffet::pipe(); + let (client_write, server_read) = fluke::buffet::pipe(); + + let serve_fut = async move { + let server_conf = Rc::new(fluke::h2::ServerConf { + ..Default::default() + }); + + let client_buf = RollMut::alloc()?; + let driver = Rc::new(TestDriver); + let io = (server_read, server_write); + fluke::h2::serve(io, server_conf, client_buf, driver).await?; + tracing::debug!("http/2 server done"); + Ok::<_, eyre::Report>(()) + }; + + fluke_buffet::spawn(async move { + serve_fut.await.unwrap(); + }); + + let config = Rc::new(httpwg::Config::default()); + httpwg::Conn::new(config, TwoHalves(client_write, client_read)) +} + +#[test] +fn rfc9113_3_starting_http2_sends_client_connection_preface() { + crate::setup_tracing_and_error_reporting(); + + fluke_buffet::start(async move { + let conn = crate::start_server(); + + #[rustfmt::skip] + httpwg + ::rfc9113 + ::_8_expressing_http_semantics_in_http2 + ::sends_headers_frame_without_path(conn) + .await.unwrap(); + }); +} diff --git a/crates/httpwg-cli/Cargo.toml b/crates/httpwg-cli/Cargo.toml new file mode 100644 index 00000000..e3fff068 --- /dev/null +++ b/crates/httpwg-cli/Cargo.toml @@ -0,0 +1,13 @@ +[package] +name = "httpwg-cli" +version = "0.1.0" +edition = "2021" + +[dependencies] +color-eyre = "0.6.3" +eyre = "0.6.12" +fluke-buffet = { version = "0.2.0", path = "../fluke-buffet" } +httpwg = { version = "0.1.1", path = "../httpwg" } +tokio = { version = "1.38.0", features = ["time"] } +tracing = "0.1.40" +tracing-subscriber = { version = "0.3.18" } diff --git a/crates/httpwg-cli/src/main.rs b/crates/httpwg-cli/src/main.rs new file mode 100644 index 00000000..574f8dc7 --- /dev/null +++ b/crates/httpwg-cli/src/main.rs @@ -0,0 +1,135 @@ +use std::{ + collections::HashMap, future::Future, net::SocketAddr, pin::Pin, rc::Rc, time::Duration, +}; + +use fluke_buffet::{net::TcpStream, IntoHalves}; +use httpwg::{rfc9113, Config, Conn}; +use tracing::Level; +use tracing_subscriber::{filter::Targets, layer::SubscriberExt, util::SubscriberInitExt}; + +fn main() { + setup_tracing_and_error_reporting(); + + let cat = catalog::(); + + // filter is set to the first argument, if there is a first argument + let filter = std::env::args().nth(1); + + fluke_buffet::start(async move { + // now run it! establish a new TCP connection for every case, and run the test. + let addr = "localhost:8000"; + let conf = Rc::new(Config { + timeout: Duration::from_secs(3), + ..Default::default() + }); + + let local_set = tokio::task::LocalSet::new(); + + for (rfc, sections) in cat { + for (section, tests) in sections { + for (test, boxed_test) in tests { + let test_name = format!("{rfc} :: {section} :: {test}"); + if let Some(filter) = &filter { + if !test_name.contains(filter) { + println!("Skipping test: {}", test_name); + continue; + } + } + + let addr: SocketAddr = addr.parse().unwrap(); + let stream = + tokio::time::timeout(Duration::from_millis(250), TcpStream::connect(addr)) + .await + .unwrap() + .unwrap(); + let conn = Conn::new(conf.clone(), stream); + let test = async move { + println!("๐Ÿ”ท Running test: {}", test_name); + boxed_test(conn).await.unwrap(); + println!("โœ… Test passed: {}", test_name); + }; + local_set.spawn_local(test); + } + } + } + + local_set.await; + }); +} + +type Catalog = + HashMap<&'static str, HashMap<&'static str, HashMap<&'static str, BoxedTest>>>; + +#[allow(unused)] +fn print_catalog(cat: &Catalog) { + for (rfc, sections) in cat { + println!("๐Ÿ“• {}", rfc); + for (section, tests) in sections { + println!(" ๐Ÿ”ท {}", section); + for test in tests.keys() { + println!(" ๐Ÿ“„ {}", test); + } + } + } +} + +fn setup_tracing_and_error_reporting() { + color_eyre::install().unwrap(); + + let targets = if let Ok(rust_log) = std::env::var("RUST_LOG") { + rust_log.parse::().unwrap() + } else { + Targets::new() + .with_default(Level::INFO) + .with_target("fluke", Level::DEBUG) + .with_target("httpwg", Level::DEBUG) + .with_target("want", Level::INFO) + }; + + let fmt_layer = tracing_subscriber::fmt::layer() + .with_ansi(true) + .with_file(false) + .with_line_number(false) + .without_time(); + + tracing_subscriber::registry() + .with(targets) + .with(fmt_layer) + .init(); +} +type BoxedTest = Box) -> Pin>>>>; + +pub fn catalog( +) -> HashMap<&'static str, HashMap<&'static str, HashMap<&'static str, BoxedTest>>> { + let mut rfcs: HashMap< + &'static str, + HashMap<&'static str, HashMap<&'static str, BoxedTest>>, + > = Default::default(); + + { + let mut sections: HashMap<&'static str, _> = Default::default(); + + { + use rfc9113::_8_expressing_http_semantics_in_http2 as s; + let mut section8: HashMap<&'static str, BoxedTest> = Default::default(); + section8.insert( + "client sends push promise frame", + Box::new(|conn: Conn| Box::pin(s::client_sends_push_promise_frame(conn))), + ); + section8.insert( + "sends connect with scheme", + Box::new(|conn: Conn| Box::pin(s::sends_connect_with_scheme(conn))), + ); + section8.insert( + "sends connect with path", + Box::new(|conn: Conn| Box::pin(s::sends_connect_with_path(conn))), + ); + + sections.insert("Section 8: Expressing HTTP Semantics in HTTP/2", section8); + } + + rfcs.insert("RFC 9113", sections); + } + + rfcs +} diff --git a/crates/httpwg-macros/src/lib.rs b/crates/httpwg-macros/src/lib.rs index 72e259b3..a60ab22e 100644 --- a/crates/httpwg-macros/src/lib.rs +++ b/crates/httpwg-macros/src/lib.rs @@ -761,6 +761,18 @@ use __group::sends_second_headers_frame_without_end_stream as test; $body } +#[test] +fn sends_headers_frame_with_incorrect_content_length_single_data_frame() { +use __group::sends_headers_frame_with_incorrect_content_length_single_data_frame as test; +$body +} + +#[test] +fn sends_headers_frame_with_incorrect_content_length_multiple_data_frames() { +use __group::sends_headers_frame_with_incorrect_content_length_multiple_data_frames as test; +$body +} + /// A field name MUST NOT contain characters in the ranges 0x00-0x20, 0x41-0x5a, /// or 0x7f-0xff (all ranges inclusive). This specifically excludes all /// non-visible ASCII characters, ASCII SP (0x20), and uppercase characters ('A' diff --git a/crates/httpwg/src/lib.rs b/crates/httpwg/src/lib.rs index 48f625dc..cdbf56c1 100644 --- a/crates/httpwg/src/lib.rs +++ b/crates/httpwg/src/lib.rs @@ -238,79 +238,116 @@ impl Conn { let (ev_tx, ev_rx) = tokio::sync::mpsc::channel::(1); let mut eof = false; - let recv_fut = async move { - let mut res_buf = RollMut::alloc()?; - 'read: loop { - if !eof { - res_buf.reserve()?; - let res; - (res, res_buf) = res_buf.read_into(16384, &mut r).await; - let n = res?; - if n == 0 { - debug!("reached EOF"); - eof = true; - } else { - trace!(%n, "read bytes (reading frame header)"); - } - } - - if eof && res_buf.is_empty() { - break 'read; - } - match Frame::parse(res_buf.filled()) { - Ok((rest, frame)) => { - res_buf.keep(rest); - debug!("< {frame:?}"); + let recv_fut = { + let config = config.clone(); + async move { + let mut res_buf = RollMut::alloc()?; + 'read: loop { + trace!("'read loop"); + + match Frame::parse(res_buf.filled()) { + Ok((rest, frame)) => { + res_buf.keep(rest); + debug!("< {frame:?}"); + + // read frame payload + let frame_len = frame.len as usize; + trace!(?frame_len, "reserving memory"); + res_buf.reserve_at_least(frame_len)?; + + let deadline = Instant::now() + config.timeout; + trace!(?frame_len, ?deadline, "reading"); + + while res_buf.len() < frame_len { + let res; + (res, res_buf) = match tokio::time::timeout_at( + deadline, + res_buf.read_into(16384, &mut r), + ) + .await + { + Ok(res) => res, + Err(_) => { + debug!(?frame_len, "timed out reading frame payload"); + break 'read; + } + }; + let n = res?; + trace!(%n, len = %res_buf.len(), "read bytes (reading frame payload)"); + + if n == 0 { + eof = true; + if res_buf.len() < frame_len { + panic!( + "peer frame header, then incomplete payload, then hung up" + ) + } + } + } - // read frame payload - let frame_len = frame.len as usize; - res_buf.reserve_at_least(frame_len)?; + let payload = if frame_len == 0 { + Roll::empty() + } else { + res_buf.take_at_most(frame_len).unwrap() + }; + assert_eq!(payload.len(), frame_len); + + trace!(%frame_len, "got frame payload"); + if ev_tx.send(Ev::Frame { frame, payload }).await.is_err() { + // I guess we stopped consuming frames, sure. + break 'read; + } + } + Err(nom::Err::Incomplete(_)) => { + if eof { + if res_buf.is_empty() { + // all good, that's eof! + break 'read; + } else { + panic!( + "peer sent incomplete frame header then hung up (buf len: {})", + res_buf.len() + ) + } + } - while res_buf.len() < frame_len { + trace!("reserving"); + res_buf.reserve()?; let res; - (res, res_buf) = res_buf.read_into(16384, &mut r).await; + trace!("re-filling buffer"); + let deadline = Instant::now() + config.timeout; + (res, res_buf) = match tokio::time::timeout_at( + deadline, + res_buf.read_into(16384, &mut r), + ) + .await + { + Ok(res) => res, + Err(_) => { + debug!("timed out reading frame header"); + break 'read; + } + }; let n = res?; - trace!(%n, len = %res_buf.len(), "read bytes (reading frame payload)"); - if n == 0 { + debug!("reached EOF"); eof = true; - if res_buf.len() < frame_len { - panic!( - "peer frame header, then incomplete payload, then hung up" - ) - } + } else { + trace!(%n, "read bytes (reading frame header)"); } - } - - let payload = if frame_len == 0 { - Roll::empty() - } else { - res_buf.take_at_most(frame_len).unwrap() - }; - assert_eq!(payload.len(), frame_len); - trace!(%frame_len, "got frame payload"); - ev_tx.send(Ev::Frame { frame, payload }).await.unwrap(); - } - Err(nom::Err::Incomplete(_)) => { - if eof { - panic!( - "peer sent incomplete frame header then hung up (buf len: {})", - res_buf.len() - ) + continue; + } + Err(nom::Err::Failure(err) | nom::Err::Error(err)) => { + debug!(?err, "got parse error"); + break; } - - continue; - } - Err(nom::Err::Failure(err) | nom::Err::Error(err)) => { - debug!(?err, "got parse error"); - break; } } - } - Ok::<_, eyre::Report>(()) + Ok::<_, eyre::Report>(()) + } }; fluke_buffet::spawn(async move { recv_fut.await.unwrap() }); @@ -930,6 +967,23 @@ impl Conn { Ok(()) } + + async fn send_req_and_expect_stream_rst( + &mut self, + stream_id: StreamId, + headers: &Headers, + ) -> eyre::Result<()> { + self.encode_and_write_headers( + stream_id, + HeadersFlags::EndHeaders | HeadersFlags::EndStream, + headers, + ) + .await?; + + self.verify_stream_error(ErrorC::ProtocolError).await?; + + Ok(()) + } } /// Parameters for tests diff --git a/crates/httpwg/src/rfc9113/_8_expressing_http_semantics_in_http2.rs b/crates/httpwg/src/rfc9113/_8_expressing_http_semantics_in_http2.rs index 7c43286b..c51161a2 100644 --- a/crates/httpwg/src/rfc9113/_8_expressing_http_semantics_in_http2.rs +++ b/crates/httpwg/src/rfc9113/_8_expressing_http_semantics_in_http2.rs @@ -35,6 +35,63 @@ pub async fn sends_second_headers_frame_without_end_stream( Ok(()) } +//--- Section 8.1.1: Malformed Messages + +// A request or response that includes message content can include a +// content-length header field. A request or response is also malformed if the +// value of a content-length header field does not equal the sum of the DATA +// frame payload lengths that form the content, unless the message is defined as +// having no content. For example, 204 or 304 responses contain no content, as +// does the response to a HEAD request. A response that is defined to have no +// content, as described in Section 6.4.1 of [HTTP], MAY have a non-zero +// content-length header field, even though no content is included in DATA +// frames. +// +// Intermediaries that process HTTP requests or responses (i.e., any +// intermediary not acting as a tunnel) MUST NOT forward a malformed request or +// response. Malformed requests or responses that are detected MUST be treated +// as a stream error (Section 5.4.2) of type PROTOCOL_ERROR. + +pub async fn sends_headers_frame_with_incorrect_content_length_single_data_frame( + mut conn: Conn, +) -> eyre::Result<()> { + let stream_id = StreamId(1); + conn.handshake().await?; + + let mut headers = conn.common_headers("POST"); + headers.append("content-length", "10"); + let block_fragment = conn.encode_headers(&headers); + conn.write_headers(stream_id, HeadersFlags::EndHeaders, block_fragment?) + .await?; + conn.write_data(stream_id, true, b"test").await?; + + conn.verify_stream_error(ErrorC::ProtocolError).await?; + + Ok(()) +} + +pub async fn sends_headers_frame_with_incorrect_content_length_multiple_data_frames< + IO: IntoHalves, +>( + mut conn: Conn, +) -> eyre::Result<()> { + let stream_id = StreamId(1); + conn.handshake().await?; + + let mut headers = conn.common_headers("POST"); + headers.append("content-length", "10"); + let block_fragment = conn.encode_headers(&headers); + conn.write_headers(stream_id, HeadersFlags::EndHeaders, block_fragment?) + .await?; + conn.write_data(stream_id, false, b"te").await?; + conn.write_data(stream_id, false, b"st").await?; + conn.write_data(stream_id, true, b"ing").await?; + + conn.verify_stream_error(ErrorC::ProtocolError).await?; + + Ok(()) +} + //--- Section 8.2.1: Field Validity /// A field name MUST NOT contain characters in the ranges 0x00-0x20, 0x41-0x5a, @@ -53,7 +110,7 @@ pub async fn sends_headers_frame_with_uppercase_field_name( let mut headers = conn.common_headers("POST"); headers.append("UPPERCASE", "oh no"); - conn.send_req_and_expect_status(StreamId(1), &headers, 400) + conn.send_req_and_expect_stream_rst(StreamId(1), &headers) .await?; Ok(()) @@ -75,7 +132,7 @@ pub async fn sends_headers_frame_with_space_in_field_name( let mut headers = conn.common_headers("POST"); headers.append("space force", "oh no"); - conn.send_req_and_expect_status(StreamId(1), &headers, 400) + conn.send_req_and_expect_stream_rst(StreamId(1), &headers) .await?; Ok(()) @@ -97,7 +154,7 @@ pub async fn sends_headers_frame_with_non_visible_ascii( let mut headers = conn.common_headers("POST"); headers.append("\x01invalid", "oh no"); - conn.send_req_and_expect_status(StreamId(1), &headers, 400) + conn.send_req_and_expect_stream_rst(StreamId(1), &headers) .await?; Ok(()) @@ -119,7 +176,7 @@ pub async fn sends_headers_frame_with_del_character( let mut headers = conn.common_headers("POST"); headers.append("\x7Finvalid", "oh no"); - conn.send_req_and_expect_status(StreamId(1), &headers, 400) + conn.send_req_and_expect_stream_rst(StreamId(1), &headers) .await?; Ok(()) @@ -141,7 +198,7 @@ pub async fn sends_headers_frame_with_non_ascii_character( let mut headers = conn.common_headers("POST"); headers.append("invรกlid", "oh no"); - conn.send_req_and_expect_status(StreamId(1), &headers, 400) + conn.send_req_and_expect_stream_rst(StreamId(1), &headers) .await?; Ok(()) @@ -162,7 +219,7 @@ pub async fn sends_headers_frame_with_colon_in_field_name( let mut headers = conn.common_headers("POST"); headers.append("invalid:field", "oh no"); - conn.send_req_and_expect_status(StreamId(1), &headers, 400) + conn.send_req_and_expect_stream_rst(StreamId(1), &headers) .await?; Ok(()) @@ -182,7 +239,7 @@ pub async fn sends_headers_frame_with_lf_in_field_value( let mut headers = conn.common_headers("POST"); headers.append("invalid-value", "oh\nno"); - conn.send_req_and_expect_status(StreamId(1), &headers, 400) + conn.send_req_and_expect_stream_rst(StreamId(1), &headers) .await?; Ok(()) @@ -202,7 +259,7 @@ pub async fn sends_headers_frame_with_cr_in_field_value( let mut headers = conn.common_headers("POST"); headers.append("invalid-value", "oh\rno"); - conn.send_req_and_expect_status(StreamId(1), &headers, 400) + conn.send_req_and_expect_stream_rst(StreamId(1), &headers) .await?; Ok(()) @@ -222,7 +279,7 @@ pub async fn sends_headers_frame_with_nul_in_field_value( let mut headers = conn.common_headers("POST"); headers.append("invalid-value", "oh\0no"); - conn.send_req_and_expect_status(StreamId(1), &headers, 400) + conn.send_req_and_expect_stream_rst(StreamId(1), &headers) .await?; Ok(()) @@ -242,7 +299,7 @@ pub async fn sends_headers_frame_with_leading_space_in_field_value( let mut headers = conn.common_headers("POST"); headers.append("connection", "keep-alive"); - conn.send_req_and_expect_status(StreamId(1), &headers, 400) + conn.send_req_and_expect_stream_rst(StreamId(1), &headers) .await?; Ok(()) @@ -309,7 +366,7 @@ pub async fn sends_headers_frame_with_proxy_connection_header( let mut headers = conn.common_headers("POST"); headers.append("proxy-connection", "keep-alive"); - conn.send_req_and_expect_status(StreamId(1), &headers, 400) + conn.send_req_and_expect_stream_rst(StreamId(1), &headers) .await?; Ok(()) @@ -332,7 +389,7 @@ pub async fn sends_headers_frame_with_keep_alive_header( let mut headers = conn.common_headers("POST"); headers.append("keep-alive", "timeout=5"); - conn.send_req_and_expect_status(StreamId(1), &headers, 400) + conn.send_req_and_expect_stream_rst(StreamId(1), &headers) .await?; Ok(()) @@ -355,7 +412,7 @@ pub async fn sends_headers_frame_with_transfer_encoding_header( let mut headers = conn.common_headers("POST"); headers.append("transfer-encoding", "chunked"); - conn.send_req_and_expect_status(StreamId(1), &headers, 400) + conn.send_req_and_expect_stream_rst(StreamId(1), &headers) .await?; Ok(()) @@ -378,7 +435,7 @@ pub async fn sends_headers_frame_with_upgrade_header( let mut headers = conn.common_headers("POST"); headers.append("upgrade", "h2c"); - conn.send_req_and_expect_status(StreamId(1), &headers, 400) + conn.send_req_and_expect_stream_rst(StreamId(1), &headers) .await?; Ok(()) @@ -410,7 +467,7 @@ pub async fn sends_headers_frame_with_te_not_trailers( let mut headers = conn.common_headers("POST"); headers.append("te", "not-trailers"); - conn.send_req_and_expect_status(StreamId(1), &headers, 400) + conn.send_req_and_expect_stream_rst(StreamId(1), &headers) .await?; Ok(()) @@ -433,7 +490,7 @@ pub async fn sends_headers_frame_with_response_pseudo_header( let mut headers = conn.common_headers("POST"); headers.append(":status", "200"); - conn.send_req_and_expect_status(StreamId(1), &headers, 400) + conn.send_req_and_expect_stream_rst(StreamId(1), &headers) .await?; Ok(()) @@ -463,14 +520,7 @@ pub async fn sends_headers_frame_with_pseudo_header_in_trailer( ) .await?; - // wait for headers frame, expect 400 status - let (frame, payload) = conn.wait_for_frame(FrameT::Headers).await.unwrap(); - assert!(frame.is_end_headers(), "this test makes that assumption"); - let headers = conn.decode_headers(payload.into())?; - let status = headers.get_first(&":status".into()).unwrap(); - let status = std::str::from_utf8(status)?; - let status = status.parse::().unwrap(); - assert_eq!(status, 400); + conn.verify_stream_error(ErrorC::ProtocolError).await?; Ok(()) } @@ -488,7 +538,7 @@ pub async fn sends_headers_frame_with_duplicate_pseudo_headers( headers.prepend(":method", "POST"); headers.prepend(":method", "POST"); - conn.send_req_and_expect_status(StreamId(1), &headers, 400) + conn.send_req_and_expect_stream_rst(StreamId(1), &headers) .await?; Ok(()) @@ -513,7 +563,7 @@ pub async fn sends_headers_frame_with_mismatched_host_authority( "host", format!("{}.different", conn.config.host).into_bytes(), ); - conn.send_req_and_expect_status(StreamId(1), &headers, 400) + conn.send_req_and_expect_stream_rst(StreamId(1), &headers) .await?; Ok(()) @@ -540,7 +590,7 @@ pub async fn sends_headers_frame_with_empty_path_component( let mut headers = conn.common_headers("POST"); headers.replace(":path", ""); - conn.send_req_and_expect_status(StreamId(1), &headers, 400) + conn.send_req_and_expect_stream_rst(StreamId(1), &headers) .await?; Ok(()) @@ -558,7 +608,7 @@ pub async fn sends_headers_frame_without_method( let mut headers = conn.common_headers("POST"); headers.remove(&":method".into()); - conn.send_req_and_expect_status(StreamId(1), &headers, 400) + conn.send_req_and_expect_stream_rst(StreamId(1), &headers) .await?; Ok(()) @@ -571,7 +621,7 @@ pub async fn sends_headers_frame_without_scheme( let mut headers = conn.common_headers("POST"); headers.remove(&":scheme".into()); - conn.send_req_and_expect_status(StreamId(1), &headers, 400) + conn.send_req_and_expect_stream_rst(StreamId(1), &headers) .await?; Ok(()) @@ -584,7 +634,7 @@ pub async fn sends_headers_frame_without_path( let mut headers = conn.common_headers("POST"); headers.remove(&":path".into()); - conn.send_req_and_expect_status(StreamId(1), &headers, 400) + conn.send_req_and_expect_stream_rst(StreamId(1), &headers) .await?; Ok(()) @@ -693,7 +743,7 @@ pub async fn sends_connect_with_scheme(mut conn: Conn) -> ey headers.append(":method", "CONNECT"); headers.append(":scheme", "https"); headers.append(":authority", "example.com:443"); - conn.send_req_and_expect_status(StreamId(1), &headers, 400) + conn.send_req_and_expect_stream_rst(StreamId(1), &headers) .await?; Ok(()) @@ -706,7 +756,7 @@ pub async fn sends_connect_with_path(mut conn: Conn) -> eyre headers.append(":method", "CONNECT"); headers.append(":path", "/"); headers.append(":authority", "example.com:443"); - conn.send_req_and_expect_status(StreamId(1), &headers, 400) + conn.send_req_and_expect_stream_rst(StreamId(1), &headers) .await?; Ok(()) @@ -719,7 +769,7 @@ pub async fn sends_connect_without_authority( let mut headers = Headers::default(); headers.append(":method", "CONNECT"); - conn.send_req_and_expect_status(StreamId(1), &headers, 400) + conn.send_req_and_expect_stream_rst(StreamId(1), &headers) .await?; Ok(()) @@ -738,7 +788,7 @@ pub async fn sends_headers_frame_with_pseudo_headers_after_regular_headers