From 8c5fe3aba7e910fcdf6b1e7d7338418d6bb9faf0 Mon Sep 17 00:00:00 2001 From: Amos Wenger Date: Tue, 12 Mar 2024 14:35:54 +0100 Subject: [PATCH 01/17] wip: Get rid of write task Closes #123 --- crates/fluke/src/h2/read.rs | 18 ++++++++---- crates/fluke/src/h2/server.rs | 52 ++--------------------------------- 2 files changed, 16 insertions(+), 54 deletions(-) diff --git a/crates/fluke/src/h2/read.rs b/crates/fluke/src/h2/read.rs index e59c9971..d7fffd3b 100644 --- a/crates/fluke/src/h2/read.rs +++ b/crates/fluke/src/h2/read.rs @@ -7,7 +7,7 @@ use std::{ use enumflags2::BitFlags; use eyre::Context; use fluke_buffet::{Piece, PieceStr, Roll, RollMut}; -use fluke_maybe_uring::io::ReadOwned; +use fluke_maybe_uring::io::{ReadOwned, WriteOwned}; use http::{ header, uri::{Authority, PathAndQuery, Scheme}, @@ -37,28 +37,36 @@ use super::{ }; /// Reads and processes h2 frames from the client. -pub(crate) struct H2ReadContext { +pub(crate) struct H2ReadContext { driver: Rc, - ev_tx: mpsc::Sender, state: ConnState, hpack_dec: fluke_hpack::Decoder<'static>, /// Whether we've received a GOAWAY frame. pub goaway_recv: bool, + + transport_w: W, + + ev_tx: mpsc::Sender, + ev_rx: mpsc::Receiver, } -impl H2ReadContext { - pub(crate) fn new(driver: Rc, ev_tx: mpsc::Sender, state: ConnState) -> Self { +impl H2ReadContext { + pub(crate) fn new(driver: Rc, state: ConnState, transport_w: W) -> Self { let mut hpack_dec = fluke_hpack::Decoder::new(); hpack_dec .set_max_allowed_table_size(Settings::default().header_table_size.try_into().unwrap()); + let (ev_tx, ev_rx) = tokio::sync::mpsc::channel::(32); + Self { driver, ev_tx, + ev_rx, state, hpack_dec, goaway_recv: false, + transport_w, } } diff --git a/crates/fluke/src/h2/server.rs b/crates/fluke/src/h2/server.rs index 32a0dbc5..7dd879db 100644 --- a/crates/fluke/src/h2/server.rs +++ b/crates/fluke/src/h2/server.rs @@ -68,55 +68,9 @@ pub async fn serve( debug!("sent settings frame"); } - let (ev_tx, ev_rx) = tokio::sync::mpsc::channel::(32); - - let h2_read_cx = H2ReadContext::new(driver.clone(), ev_tx, state); - - { - let mut read_task = std::pin::pin!(h2_read_cx.read_loop(client_buf, transport_r)); - let mut write_task = - std::pin::pin!(super::write::h2_write_loop(ev_rx, transport_w, out_scratch)); - - tokio::select! { - res = &mut read_task => { - match res { - Err(e) => { - return Err(e.wrap_err("h2 read (finished first)")); - } - Ok(()) => { - debug!("read task finished, waiting on write task now"); - let res = write_task.await; - match res { - Err(e) => { - if is_peer_gone(&e) { - debug!(%e, "write task failed with peer gone"); - } else { - return Err(e.wrap_err("h2 write (finished second)")); - } - } - Ok(()) => { - debug!("write task finished okay"); - } - } - } - } - }, - res = write_task.as_mut() => { - match res { - Err(e) => { - if is_peer_gone(&e) { - debug!(%e, "write task failed with peer gone"); - } else { - return Err(e.wrap_err("h2 write (finished first)")); - } - } - Ok(()) => { - debug!("write task finished, giving up read task"); - } - } - }, - }; - }; + H2ReadContext::new(driver.clone(), state, transport_w) + .read_loop(client_buf, transport_r) + .await?; debug!("finished serving"); Ok(()) From 21421973c5e3893bd1cade60ef429d2df547316d Mon Sep 17 00:00:00 2001 From: Amos Wenger Date: Tue, 12 Mar 2024 14:41:12 +0100 Subject: [PATCH 02/17] Give out_scratch to h2readcontext --- crates/fluke/src/h2/read.rs | 34 ++++++++++++++++++++++++---------- crates/fluke/src/h2/server.rs | 17 +++-------------- crates/fluke/src/h2/types.rs | 11 +++++++++++ 3 files changed, 38 insertions(+), 24 deletions(-) diff --git a/crates/fluke/src/h2/read.rs b/crates/fluke/src/h2/read.rs index d7fffd3b..6dfe3baf 100644 --- a/crates/fluke/src/h2/read.rs +++ b/crates/fluke/src/h2/read.rs @@ -41,6 +41,8 @@ pub(crate) struct H2ReadContext { driver: Rc, state: ConnState, hpack_dec: fluke_hpack::Decoder<'static>, + hpack_enc: fluke_hpack::Encoder<'static>, + out_scratch: RollMut, /// Whether we've received a GOAWAY frame. pub goaway_recv: bool, @@ -52,11 +54,18 @@ pub(crate) struct H2ReadContext { } impl H2ReadContext { - pub(crate) fn new(driver: Rc, state: ConnState, transport_w: W) -> Self { + pub(crate) fn new( + driver: Rc, + state: ConnState, + transport_w: W, + out_scratch: RollMut, + ) -> Self { let mut hpack_dec = fluke_hpack::Decoder::new(); hpack_dec .set_max_allowed_table_size(Settings::default().header_table_size.try_into().unwrap()); + let hpack_enc = fluke_hpack::Encoder::new(); + let (ev_tx, ev_rx) = tokio::sync::mpsc::channel::(32); Self { @@ -65,12 +74,15 @@ impl H2ReadContext { ev_rx, state, hpack_dec, + hpack_enc, + out_scratch, goaway_recv: false, transport_w, } } - pub(crate) async fn read_loop( + /// Reads and process h2 frames from the client. + pub(crate) async fn work( mut self, client_buf: RollMut, transport_r: impl ReadOwned, @@ -85,13 +97,17 @@ impl H2ReadContext { // FIXME: the process_task should update this let max_frame_size = Rc::new(AtomicU32::new(self.state.self_settings.max_frame_size)); - let mut io_task = - std::pin::pin!(Self::io_loop(client_buf, transport_r, tx, max_frame_size)); + let mut deframe_task = std::pin::pin!(Self::deframe_loop( + client_buf, + transport_r, + tx, + max_frame_size + )); let mut process_task = std::pin::pin!(self.process_loop(rx)); tokio::select! { - res = &mut io_task => { - debug!(?res, "h2 io task finished"); + res = &mut deframe_task => { + debug!(?res, "h2 deframe task finished"); if let Err(H2ConnectionError::ReadError(e)) = res { let mut should_ignore_err = false; @@ -119,8 +135,6 @@ impl H2ReadContext { if let Err(err) = res { goaway_err = Some(err); } - - // probably don't need to wait for the io task there } } } @@ -142,7 +156,7 @@ impl H2ReadContext { Ok(()) } - async fn io_loop( + async fn deframe_loop( mut client_buf: RollMut, mut transport_r: impl ReadOwned, tx: mpsc::Sender<(Frame, Roll)>, @@ -230,7 +244,7 @@ impl H2ReadContext { } if tx.send((frame, payload)).await.is_err() { - debug!("h2 io loop: receiver dropped, closing connection"); + debug!("h2 deframer: receiver dropped, closing connection"); return Ok(()); } } diff --git a/crates/fluke/src/h2/server.rs b/crates/fluke/src/h2/server.rs index 7dd879db..938503dc 100644 --- a/crates/fluke/src/h2/server.rs +++ b/crates/fluke/src/h2/server.rs @@ -52,7 +52,7 @@ pub async fn serve( let mut out_scratch = RollMut::alloc()?; - // we have to send a settings frame + // we have to send an initial settings frame { let payload_roll = state.self_settings.into_roll(&mut out_scratch)?; let frame_roll = Frame::new( @@ -68,21 +68,10 @@ pub async fn serve( debug!("sent settings frame"); } - H2ReadContext::new(driver.clone(), state, transport_w) - .read_loop(client_buf, transport_r) + H2ReadContext::new(driver.clone(), state, transport_w, out_scratch) + .work(client_buf, transport_r) .await?; debug!("finished serving"); Ok(()) } - -fn is_peer_gone(e: &eyre::Report) -> bool { - if let Some(io_error) = e.root_cause().downcast_ref::() { - matches!( - io_error.kind(), - std::io::ErrorKind::BrokenPipe | std::io::ErrorKind::ConnectionReset - ) - } else { - false - } -} diff --git a/crates/fluke/src/h2/types.rs b/crates/fluke/src/h2/types.rs index 1ae01e95..7dcf3735 100644 --- a/crates/fluke/src/h2/types.rs +++ b/crates/fluke/src/h2/types.rs @@ -267,3 +267,14 @@ impl fmt::Debug for H2EventPayload { #[derive(thiserror::Error, Debug)] #[error("the peer closed the connection unexpectedly")] pub(crate) struct ConnectionClosed; + +fn is_peer_gone(e: &eyre::Report) -> bool { + if let Some(io_error) = e.root_cause().downcast_ref::() { + matches!( + io_error.kind(), + std::io::ErrorKind::BrokenPipe | std::io::ErrorKind::ConnectionReset + ) + } else { + false + } +} From c92841c523dcd48d1c54676ad8a47081feb86738 Mon Sep 17 00:00:00 2001 From: Amos Wenger Date: Tue, 12 Mar 2024 14:46:33 +0100 Subject: [PATCH 03/17] Introduce write_frame --- crates/fluke/src/h2/read.rs | 505 +++++++++++++++++++---------------- crates/fluke/src/h2/types.rs | 3 + 2 files changed, 272 insertions(+), 236 deletions(-) diff --git a/crates/fluke/src/h2/read.rs b/crates/fluke/src/h2/read.rs index 6dfe3baf..78a39147 100644 --- a/crates/fluke/src/h2/read.rs +++ b/crates/fluke/src/h2/read.rs @@ -256,152 +256,91 @@ impl H2ReadContext { &mut self, mut rx: mpsc::Receiver<(Frame, Roll)>, ) -> Result<(), H2ConnectionError> { - 'process_frames: loop { - let (frame, mut payload) = match rx.recv().await { - Some(t) => t, - None => { - debug!("h2 process task: peer hung up"); - break; - } - }; - - match frame.frame_type { - FrameType::Data(flags) => { - let ss = self.state.streams.get_mut(&frame.stream_id).ok_or( - H2ConnectionError::StreamClosed { - stream_id: frame.stream_id, - }, - )?; - - match ss { - StreamState::Open(tx) => { - if tx - .send(Ok(PieceOrTrailers::Piece(payload.into()))) - .await - .is_err() - { - warn!( - "TODO: The body is being ignored, we should reset the stream" - ); - } - - if flags.contains(DataFlags::EndStream) { - *ss = StreamState::HalfClosedRemote; - } - } - StreamState::HalfClosedRemote => { - debug!( - stream_id = %frame.stream_id, - "Received data for closed stream" - ); - self.send_rst(frame.stream_id, H2StreamError::StreamClosed) - .await + loop { + tokio::select! { + _ = self.ev_rx.recv() => { + todo!("handle conn events") + }, + maybe_frame = rx.recv() => { + if let Some((frame, payload)) = maybe_frame { + let res = self.process_frame(frame, payload, &mut rx).await; + if let Err(e) = res { + return Err(e); } + } else { + debug!("h2 process task: peer hung up"); + break; } } - FrameType::Headers(flags) => { - if flags.contains(HeadersFlags::Priority) { - let pri_spec; - (payload, pri_spec) = PrioritySpec::parse(payload) - .finish() - .map_err(|err| eyre::eyre!("parsing error: {err:?}"))?; - debug!(exclusive = %pri_spec.exclusive, stream_dependency = ?pri_spec.stream_dependency, weight = %pri_spec.weight, "received priority, exclusive"); - - if pri_spec.stream_dependency == frame.stream_id { - return Err(H2ConnectionError::HeadersInvalidPriority { - stream_id: frame.stream_id, - }); - } - } + } + } - let headers_or_trailers; - let mode; + Ok(()) + } - match self.state.streams.get_mut(&frame.stream_id) { - None => { - headers_or_trailers = HeadersOrTrailers::Headers; - debug!( - stream_id = %frame.stream_id, - last_stream_id = %self.state.last_stream_id, - next_stream_count = %self.state.streams.len() + 1, - "Receiving headers", - ); + async fn write_frame(&mut self, frame: Frame, payload: Roll) -> Result<(), H2ConnectionError> { + let frame_roll = frame.into_roll(&mut self.out_scratch)?; - if frame.stream_id.is_server_initiated() { - return Err(H2ConnectionError::ClientSidShouldBeOdd); - } + if payload.is_empty() { + self.transport_w + .write_all(frame_roll) + .await + .map_err(H2ConnectionError::WriteError)?; + } else { + self.transport_w + .writev_all(&[frame_roll, payload]) + .await + .map_err(H2ConnectionError::WriteError)?; + } - if frame.stream_id <= self.state.last_stream_id { - debug!( - frame_stream_id = %frame.stream_id, - last_stream_id = %self.state.last_stream_id, - "Received headers for invalid stream ID" - ); - - // this stream may have existed, but it no longer does: - return Err(H2ConnectionError::StreamClosed { - stream_id: frame.stream_id, - }); - } else { - // TODO: if we're shutting down, ignore streams higher - // than the last one we accepted. - - let max_concurrent_streams = - self.state.self_settings.max_concurrent_streams; - let num_streams_if_accept = self.state.streams.len() + 1; - if num_streams_if_accept > max_concurrent_streams as _ { - // reset the stream, indicating we refused it - self.send_rst(frame.stream_id, H2StreamError::RefusedStream) - .await; - - // but we still need to skip over any continuation frames - mode = ReadHeadersMode::Skip; - } else { - self.state.last_stream_id = frame.stream_id; - mode = ReadHeadersMode::Process; - } - } - } - Some(StreamState::Open(_)) => { - headers_or_trailers = HeadersOrTrailers::Trailers; - debug!("Receiving trailers for stream {}", frame.stream_id); + Ok(()) + } - if flags.contains(HeadersFlags::EndStream) { - // good, that's what we expect - mode = ReadHeadersMode::Process; - } else { - // ignore trailers, we're not accepting the stream - mode = ReadHeadersMode::Skip; + async fn process_frame( + &mut self, + frame: Frame, + payload: Roll, + rx: &mut mpsc::Receiver<(Frame, Roll)>, + ) -> Result<(), H2ConnectionError> { + match frame.frame_type { + FrameType::Data(flags) => { + let ss = self.state.streams.get_mut(&frame.stream_id).ok_or( + H2ConnectionError::StreamClosed { + stream_id: frame.stream_id, + }, + )?; - self.send_rst(frame.stream_id, H2StreamError::TrailersNotEndStream) - .await - } + match ss { + StreamState::Open(tx) => { + if tx + .send(Ok(PieceOrTrailers::Piece(payload.into()))) + .await + .is_err() + { + warn!("TODO: The body is being ignored, we should reset the stream"); } - Some(StreamState::HalfClosedRemote) => { - return Err(H2ConnectionError::StreamClosed { - stream_id: frame.stream_id, - }); + + if flags.contains(DataFlags::EndStream) { + *ss = StreamState::HalfClosedRemote; } } - - self.read_headers( - headers_or_trailers, - mode, - flags, - frame.stream_id, - payload, - &mut rx, - ) - .await?; + StreamState::HalfClosedRemote => { + debug!( + stream_id = %frame.stream_id, + "Received data for closed stream" + ); + self.send_rst(frame.stream_id, H2StreamError::StreamClosed) + .await + } } - FrameType::Priority => { - let pri_spec = match PrioritySpec::parse(payload) { - Ok((_rest, pri_spec)) => pri_spec, - Err(e) => { - todo!("handle connection error: invalid priority frame {e}") - } - }; - debug!(?pri_spec, "received priority frame"); + } + FrameType::Headers(flags) => { + if flags.contains(HeadersFlags::Priority) { + let pri_spec; + (payload, pri_spec) = PrioritySpec::parse(payload) + .finish() + .map_err(|err| eyre::eyre!("parsing error: {err:?}"))?; + debug!(exclusive = %pri_spec.exclusive, stream_dependency = ?pri_spec.stream_dependency, weight = %pri_spec.weight, "received priority, exclusive"); if pri_spec.stream_dependency == frame.stream_id { return Err(H2ConnectionError::HeadersInvalidPriority { @@ -409,138 +348,232 @@ impl H2ReadContext { }); } } - FrameType::RstStream => match self.state.streams.remove(&frame.stream_id) { + + let headers_or_trailers; + let mode; + + match self.state.streams.get_mut(&frame.stream_id) { None => { - return Err(H2ConnectionError::RstStreamForUnknownStream { - stream_id: frame.stream_id, - }) - } - Some(ss) => match ss { - StreamState::Open(body_tx) => { - _ = body_tx - .send(Err(H2StreamError::ReceivedRstStream.into())) - .await; - } - StreamState::HalfClosedRemote => { - // good + headers_or_trailers = HeadersOrTrailers::Headers; + debug!( + stream_id = %frame.stream_id, + last_stream_id = %self.state.last_stream_id, + next_stream_count = %self.state.streams.len() + 1, + "Receiving headers", + ); + + if frame.stream_id.is_server_initiated() { + return Err(H2ConnectionError::ClientSidShouldBeOdd); } - }, - }, - FrameType::Settings(s) => { - if frame.stream_id != StreamId::CONNECTION { - return Err(H2ConnectionError::SettingsWithNonZeroStreamId { - stream_id: frame.stream_id, - }); - } - if s.contains(SettingsFlags::Ack) { - debug!("Peer has acknowledged our settings, cool"); - if !payload.is_empty() { - return Err(H2ConnectionError::SettingsAckWithPayload { - len: payload.len() as _, + if frame.stream_id <= self.state.last_stream_id { + debug!( + frame_stream_id = %frame.stream_id, + last_stream_id = %self.state.last_stream_id, + "Received headers for invalid stream ID" + ); + + // this stream may have existed, but it no longer does: + return Err(H2ConnectionError::StreamClosed { + stream_id: frame.stream_id, }); + } else { + // TODO: if we're shutting down, ignore streams higher + // than the last one we accepted. + + let max_concurrent_streams = + self.state.self_settings.max_concurrent_streams; + let num_streams_if_accept = self.state.streams.len() + 1; + if num_streams_if_accept > max_concurrent_streams as _ { + // reset the stream, indicating we refused it + self.send_rst(frame.stream_id, H2StreamError::RefusedStream) + .await; + + // but we still need to skip over any continuation frames + mode = ReadHeadersMode::Skip; + } else { + self.state.last_stream_id = frame.stream_id; + mode = ReadHeadersMode::Process; + } } - } else { - let (_, settings) = Settings::parse(payload) - .finish() - .map_err(|err| eyre::eyre!("parsing error: {err:?}"))?; - let new_max_header_table_size = settings.header_table_size; - debug!(?settings, "Received settings"); - self.state.peer_settings = settings; - - if self - .ev_tx - .send(H2ConnEvent::AcknowledgeSettings { - new_max_header_table_size, - }) - .await - .is_err() - { - return Err(eyre::eyre!( - "could not send H2 acknowledge settings event" - ) - .into()); + } + Some(StreamState::Open(_)) => { + headers_or_trailers = HeadersOrTrailers::Trailers; + debug!("Receiving trailers for stream {}", frame.stream_id); + + if flags.contains(HeadersFlags::EndStream) { + // good, that's what we expect + mode = ReadHeadersMode::Process; + } else { + // ignore trailers, we're not accepting the stream + mode = ReadHeadersMode::Skip; + + self.send_rst(frame.stream_id, H2StreamError::TrailersNotEndStream) + .await } } - } - FrameType::PushPromise => { - return Err(H2ConnectionError::ClientSentPushPromise); - } - FrameType::Ping(flags) => { - if frame.stream_id != StreamId::CONNECTION { - return Err(H2ConnectionError::PingFrameWithNonZeroStreamId { + Some(StreamState::HalfClosedRemote) => { + return Err(H2ConnectionError::StreamClosed { stream_id: frame.stream_id, }); } + } - if frame.len != 8 { - return Err(H2ConnectionError::PingFrameInvalidLength { len: frame.len }); + self.read_headers( + headers_or_trailers, + mode, + flags, + frame.stream_id, + payload, + &mut rx, + ) + .await?; + } + FrameType::Priority => { + let pri_spec = match PrioritySpec::parse(payload) { + Ok((_rest, pri_spec)) => pri_spec, + Err(e) => { + todo!("handle connection error: invalid priority frame {e}") } + }; + debug!(?pri_spec, "received priority frame"); - if flags.contains(PingFlags::Ack) { - // TODO: check that payload matches the one we sent? - continue 'process_frames; + if pri_spec.stream_dependency == frame.stream_id { + return Err(H2ConnectionError::HeadersInvalidPriority { + stream_id: frame.stream_id, + }); + } + } + FrameType::RstStream => match self.state.streams.remove(&frame.stream_id) { + None => { + return Err(H2ConnectionError::RstStreamForUnknownStream { + stream_id: frame.stream_id, + }) + } + Some(ss) => match ss { + StreamState::Open(body_tx) => { + _ = body_tx + .send(Err(H2StreamError::ReceivedRstStream.into())) + .await; } - - if self.ev_tx.send(H2ConnEvent::Ping(payload)).await.is_err() { - return Err(eyre::eyre!("could not send H2 ping event").into()); + StreamState::HalfClosedRemote => { + // good } + }, + }, + FrameType::Settings(s) => { + if frame.stream_id != StreamId::CONNECTION { + return Err(H2ConnectionError::SettingsWithNonZeroStreamId { + stream_id: frame.stream_id, + }); } - FrameType::GoAway => { - if frame.stream_id != StreamId::CONNECTION { - return Err(H2ConnectionError::GoAwayWithNonZeroStreamId { - stream_id: frame.stream_id, - }); - } - - self.goaway_recv = true; - // TODO: this should probably have other effects than setting - // this flag. - } - FrameType::WindowUpdate => { - if payload.len() != 4 { - return Err(H2ConnectionError::WindowUpdateInvalidLength { + if s.contains(SettingsFlags::Ack) { + debug!("Peer has acknowledged our settings, cool"); + if !payload.is_empty() { + return Err(H2ConnectionError::SettingsAckWithPayload { len: payload.len() as _, }); } - - let increment; - (_, (_, increment)) = parse_reserved_and_u31(payload) + } else { + let (_, settings) = Settings::parse(payload) .finish() .map_err(|err| eyre::eyre!("parsing error: {err:?}"))?; - - if increment == 0 { - return Err(H2ConnectionError::WindowUpdateZeroIncrement); + let new_max_header_table_size = settings.header_table_size; + debug!(?settings, "Received settings"); + self.state.peer_settings = settings; + + if self + .ev_tx + .send(H2ConnEvent::AcknowledgeSettings { + new_max_header_table_size, + }) + .await + .is_err() + { + return Err( + eyre::eyre!("could not send H2 acknowledge settings event").into() + ); } + } + } + FrameType::PushPromise => { + return Err(H2ConnectionError::ClientSentPushPromise); + } + FrameType::Ping(flags) => { + if frame.stream_id != StreamId::CONNECTION { + return Err(H2ConnectionError::PingFrameWithNonZeroStreamId { + stream_id: frame.stream_id, + }); + } - if frame.stream_id == StreamId::CONNECTION { - debug!("TODO: ignoring connection-wide window update"); - } else { - match self.state.streams.get_mut(&frame.stream_id) { - None => { - return Err(H2ConnectionError::WindowUpdateForUnknownStream { - stream_id: frame.stream_id, - }); - } - Some(_ss) => { - debug!("TODO: handle window update for stream {}", frame.stream_id) - } - } - } + if frame.len != 8 { + return Err(H2ConnectionError::PingFrameInvalidLength { len: frame.len }); + } + + if flags.contains(PingFlags::Ack) { + // TODO: check that payload matches the one we sent? + continue 'process_frames; } - FrameType::Continuation(_flags) => { - return Err(H2ConnectionError::UnexpectedContinuationFrame { + + if self.ev_tx.send(H2ConnEvent::Ping(payload)).await.is_err() { + return Err(eyre::eyre!("could not send H2 ping event").into()); + } + } + FrameType::GoAway => { + if frame.stream_id != StreamId::CONNECTION { + return Err(H2ConnectionError::GoAwayWithNonZeroStreamId { stream_id: frame.stream_id, }); } - FrameType::Unknown(ft) => { - trace!( - "ignoring unknown frame with type 0x{:x}, flags 0x{:x}", - ft.ty, - ft.flags - ); + + self.goaway_recv = true; + + // TODO: this should probably have other effects than setting + // this flag. + } + FrameType::WindowUpdate => { + if payload.len() != 4 { + return Err(H2ConnectionError::WindowUpdateInvalidLength { + len: payload.len() as _, + }); + } + + let increment; + (_, (_, increment)) = parse_reserved_and_u31(payload) + .finish() + .map_err(|err| eyre::eyre!("parsing error: {err:?}"))?; + + if increment == 0 { + return Err(H2ConnectionError::WindowUpdateZeroIncrement); } + + if frame.stream_id == StreamId::CONNECTION { + debug!("TODO: ignoring connection-wide window update"); + } else { + match self.state.streams.get_mut(&frame.stream_id) { + None => { + return Err(H2ConnectionError::WindowUpdateForUnknownStream { + stream_id: frame.stream_id, + }); + } + Some(_ss) => { + debug!("TODO: handle window update for stream {}", frame.stream_id) + } + } + } + } + FrameType::Continuation(_flags) => { + return Err(H2ConnectionError::UnexpectedContinuationFrame { + stream_id: frame.stream_id, + }); + } + FrameType::Unknown(ft) => { + trace!( + "ignoring unknown frame with type 0x{:x}, flags 0x{:x}", + ft.ty, + ft.flags + ); } } diff --git a/crates/fluke/src/h2/types.rs b/crates/fluke/src/h2/types.rs index 7dcf3735..b1f55274 100644 --- a/crates/fluke/src/h2/types.rs +++ b/crates/fluke/src/h2/types.rs @@ -135,6 +135,9 @@ pub(crate) enum H2ConnectionError { #[error("error reading/parsing H2 frame: {0:?}")] ReadError(eyre::Report), + #[error("error writing H2 frame: {0:?}")] + WriteError(std::io::Error), + #[error("received rst frame for unknown stream")] RstStreamForUnknownStream { stream_id: StreamId }, From b197c61691a57e3e27a82d48eb398ae96fed1002 Mon Sep 17 00:00:00 2001 From: Amos Wenger Date: Tue, 12 Mar 2024 14:49:08 +0100 Subject: [PATCH 04/17] Move ping to write_frame --- crates/fluke/src/h2/read.rs | 10 ++++++---- crates/fluke/src/h2/types.rs | 1 - crates/fluke/src/h2/write.rs | 14 -------------- 3 files changed, 6 insertions(+), 19 deletions(-) diff --git a/crates/fluke/src/h2/read.rs b/crates/fluke/src/h2/read.rs index 78a39147..9330a28f 100644 --- a/crates/fluke/src/h2/read.rs +++ b/crates/fluke/src/h2/read.rs @@ -513,12 +513,14 @@ impl H2ReadContext { if flags.contains(PingFlags::Ack) { // TODO: check that payload matches the one we sent? - continue 'process_frames; + return Ok(()); } - if self.ev_tx.send(H2ConnEvent::Ping(payload)).await.is_err() { - return Err(eyre::eyre!("could not send H2 ping event").into()); - } + // send pong frame + let flags = PingFlags::Ack.into(); + let frame = Frame::new(FrameType::Ping(flags), StreamId::CONNECTION) + .with_len(payload.len() as u32); + self.write_frame(frame, payload).await?; } FrameType::GoAway => { if frame.stream_id != StreamId::CONNECTION { diff --git a/crates/fluke/src/h2/types.rs b/crates/fluke/src/h2/types.rs index b1f55274..cd2fadab 100644 --- a/crates/fluke/src/h2/types.rs +++ b/crates/fluke/src/h2/types.rs @@ -230,7 +230,6 @@ pub(crate) enum HeadersOrTrailers { } pub(crate) enum H2ConnEvent { - Ping(Roll), ServerEvent(H2Event), AcknowledgeSettings { new_max_header_table_size: u32, diff --git a/crates/fluke/src/h2/write.rs b/crates/fluke/src/h2/write.rs index aef86958..5d1f6d92 100644 --- a/crates/fluke/src/h2/write.rs +++ b/crates/fluke/src/h2/write.rs @@ -96,20 +96,6 @@ pub(crate) async fn h2_write_loop( } } } - H2ConnEvent::Ping(payload) => { - // send pong frame - let flags = PingFlags::Ack.into(); - let frame = Frame::new(FrameType::Ping(flags), StreamId::CONNECTION) - .with_len(payload.len() as u32); - transport_w - .writev_all( - PieceList::default() - .with(frame.into_roll(&mut out_scratch)?) - .with(payload), - ) - .await - .wrap_err("writing pong")?; - } H2ConnEvent::GoAway { err, last_stream_id, From cf2ba8e406f2216717fcff9124f5504317925357 Mon Sep 17 00:00:00 2001 From: Amos Wenger Date: Tue, 12 Mar 2024 14:53:59 +0100 Subject: [PATCH 05/17] Move goaway writing away from write --- crates/fluke/src/h2/read.rs | 41 ++++++++++++++++++++++++------------ crates/fluke/src/h2/types.rs | 4 ---- crates/fluke/src/h2/write.rs | 36 ------------------------------- 3 files changed, 27 insertions(+), 54 deletions(-) diff --git a/crates/fluke/src/h2/read.rs b/crates/fluke/src/h2/read.rs index 9330a28f..08600f86 100644 --- a/crates/fluke/src/h2/read.rs +++ b/crates/fluke/src/h2/read.rs @@ -1,5 +1,6 @@ use std::{ borrow::Cow, + io::Write, rc::Rc, sync::atomic::{AtomicU32, Ordering}, }; @@ -140,17 +141,29 @@ impl H2ReadContext { } if let Some(err) = goaway_err { - if self - .ev_tx - .send(H2ConnEvent::GoAway { - err, - last_stream_id: self.state.last_stream_id, - }) - .await - .is_err() - { - debug!("error sending goaway"); - } + let error_code = err.as_known_error_code(); + debug!("Connection error: {err} ({err:?}) (code {error_code:?})"); + + // let's put something useful in debug data + let additional_debug_data = format!("{err}").into_bytes(); + + // TODO: figure out graceful shutdown: this would involve sending a goaway + // before this point, and processing all the connections we've accepted + debug!(last_stream_id = %self.state.last_stream_id, ?error_code, "Sending GoAway"); + let payload = + self.out_scratch + .put_to_roll(8 + additional_debug_data.len(), |mut slice| { + use byteorder::{BigEndian, WriteBytesExt}; + // TODO: do we ever need to write the reserved bit? + slice.write_u32::(self.state.last_stream_id.0)?; + slice.write_u32::(error_code.repr())?; + slice.write_all(additional_debug_data.as_slice())?; + + Ok(()) + })?; + + let frame = Frame::new(FrameType::GoAway, StreamId::CONNECTION); + self.write_frame(frame, payload).await?; } Ok(()) @@ -299,8 +312,8 @@ impl H2ReadContext { async fn process_frame( &mut self, frame: Frame, - payload: Roll, - rx: &mut mpsc::Receiver<(Frame, Roll)>, + mut payload: Roll, + mut rx: &mut mpsc::Receiver<(Frame, Roll)>, ) -> Result<(), H2ConnectionError> { match frame.frame_type { FrameType::Data(flags) => { @@ -425,7 +438,7 @@ impl H2ReadContext { flags, frame.stream_id, payload, - &mut rx, + rx, ) .await?; } diff --git a/crates/fluke/src/h2/types.rs b/crates/fluke/src/h2/types.rs index cd2fadab..e72faa46 100644 --- a/crates/fluke/src/h2/types.rs +++ b/crates/fluke/src/h2/types.rs @@ -234,10 +234,6 @@ pub(crate) enum H2ConnEvent { AcknowledgeSettings { new_max_header_table_size: u32, }, - GoAway { - err: H2ConnectionError, - last_stream_id: StreamId, - }, RstStream { stream_id: StreamId, error_code: KnownErrorCode, diff --git a/crates/fluke/src/h2/write.rs b/crates/fluke/src/h2/write.rs index 5d1f6d92..d9e16c0f 100644 --- a/crates/fluke/src/h2/write.rs +++ b/crates/fluke/src/h2/write.rs @@ -96,42 +96,6 @@ pub(crate) async fn h2_write_loop( } } } - H2ConnEvent::GoAway { - err, - last_stream_id, - } => { - let error_code = err.as_known_error_code(); - debug!("Connection error: {err} ({err:?}) (code {error_code:?})"); - - // let's put something useful in debug data - let additional_debug_data = format!("{err}").into_bytes(); - - debug!(%last_stream_id, ?error_code, "Sending GoAway"); - let header = out_scratch.put_to_roll(8, |mut slice| { - use byteorder::{BigEndian, WriteBytesExt}; - // TODO: do we ever need to write the reserved bit? - slice.write_u32::(last_stream_id.0)?; - slice.write_u32::(error_code.repr())?; - - Ok(()) - })?; - - let frame = Frame::new(FrameType::GoAway, StreamId::CONNECTION).with_len( - (header.len() + additional_debug_data.len()) - .try_into() - .unwrap(), - ); - - transport_w - .writev_all( - PieceList::default() - .with(frame.into_roll(&mut out_scratch)?) - .with(header) - .with(additional_debug_data), - ) - .await - .wrap_err("writing goaway")?; - } H2ConnEvent::RstStream { stream_id, error_code, From ba2973a37ba0d48360c93828372aedfc0e87969b Mon Sep 17 00:00:00 2001 From: Amos Wenger Date: Tue, 12 Mar 2024 14:57:51 +0100 Subject: [PATCH 06/17] Deprecate H2ConnEvent::RstStream --- crates/fluke/src/h2/read.rs | 48 +++++++++++++++++++---------------- crates/fluke/src/h2/server.rs | 2 +- crates/fluke/src/h2/types.rs | 10 ++------ 3 files changed, 29 insertions(+), 31 deletions(-) diff --git a/crates/fluke/src/h2/read.rs b/crates/fluke/src/h2/read.rs index 08600f86..c348cb93 100644 --- a/crates/fluke/src/h2/read.rs +++ b/crates/fluke/src/h2/read.rs @@ -5,6 +5,7 @@ use std::{ sync::atomic::{AtomicU32, Ordering}, }; +use byteorder::{BigEndian, WriteBytesExt}; use enumflags2::BitFlags; use eyre::Context; use fluke_buffet::{Piece, PieceStr, Roll, RollMut}; @@ -144,7 +145,7 @@ impl H2ReadContext { let error_code = err.as_known_error_code(); debug!("Connection error: {err} ({err:?}) (code {error_code:?})"); - // let's put something useful in debug data + // TODO: don't heap-allocate here let additional_debug_data = format!("{err}").into_bytes(); // TODO: figure out graceful shutdown: this would involve sending a goaway @@ -153,8 +154,6 @@ impl H2ReadContext { let payload = self.out_scratch .put_to_roll(8 + additional_debug_data.len(), |mut slice| { - use byteorder::{BigEndian, WriteBytesExt}; - // TODO: do we ever need to write the reserved bit? slice.write_u32::(self.state.last_stream_id.0)?; slice.write_u32::(error_code.repr())?; slice.write_all(additional_debug_data.as_slice())?; @@ -313,7 +312,7 @@ impl H2ReadContext { &mut self, frame: Frame, mut payload: Roll, - mut rx: &mut mpsc::Receiver<(Frame, Roll)>, + rx: &mut mpsc::Receiver<(Frame, Roll)>, ) -> Result<(), H2ConnectionError> { match frame.frame_type { FrameType::Data(flags) => { @@ -342,8 +341,8 @@ impl H2ReadContext { stream_id = %frame.stream_id, "Received data for closed stream" ); - self.send_rst(frame.stream_id, H2StreamError::StreamClosed) - .await + self.rst(frame.stream_id, H2StreamError::StreamClosed) + .await?; } } } @@ -399,7 +398,7 @@ impl H2ReadContext { let num_streams_if_accept = self.state.streams.len() + 1; if num_streams_if_accept > max_concurrent_streams as _ { // reset the stream, indicating we refused it - self.send_rst(frame.stream_id, H2StreamError::RefusedStream) + self.rst(frame.stream_id, H2StreamError::RefusedStream) .await; // but we still need to skip over any continuation frames @@ -421,8 +420,8 @@ impl H2ReadContext { // ignore trailers, we're not accepting the stream mode = ReadHeadersMode::Skip; - self.send_rst(frame.stream_id, H2StreamError::TrailersNotEndStream) - .await + self.rst(frame.stream_id, H2StreamError::TrailersNotEndStream) + .await?; } } Some(StreamState::HalfClosedRemote) => { @@ -595,23 +594,28 @@ impl H2ReadContext { Ok(()) } - async fn send_rst(&mut self, stream_id: StreamId, e: H2StreamError) { + /// Send a RST_STREAM frame to the peer. + async fn rst( + &mut self, + stream_id: StreamId, + e: H2StreamError, + ) -> Result<(), H2ConnectionError> { + self.state.streams.remove(&stream_id); + let error_code = e.as_known_error_code(); debug!("Sending rst because: {e} (known error code: {error_code:?})"); - if self - .ev_tx - .send(H2ConnEvent::RstStream { - stream_id, - error_code, - }) - .await - .is_err() - { - debug!("error sending rst"); - } + debug!(%stream_id, ?error_code, "Sending RstStream"); + let payload = self.out_scratch.put_to_roll(4, |mut slice| { + slice.write_u32::(error_code.repr())?; + Ok(()) + })?; - self.state.streams.remove(&stream_id); + let frame = Frame::new(FrameType::RstStream, stream_id) + .with_len((payload.len()).try_into().unwrap()); + self.write_frame(frame, payload); + + Ok(()) } async fn read_headers( diff --git a/crates/fluke/src/h2/server.rs b/crates/fluke/src/h2/server.rs index 938503dc..b7793bdf 100644 --- a/crates/fluke/src/h2/server.rs +++ b/crates/fluke/src/h2/server.rs @@ -6,7 +6,7 @@ use crate::{ h2::{ parse::{self, Frame, FrameType, StreamId}, read::H2ReadContext, - types::{ConnState, H2ConnEvent}, + types::ConnState, }, util::read_and_parse, ServerDriver, diff --git a/crates/fluke/src/h2/types.rs b/crates/fluke/src/h2/types.rs index e72faa46..82f303b9 100644 --- a/crates/fluke/src/h2/types.rs +++ b/crates/fluke/src/h2/types.rs @@ -1,6 +1,6 @@ use std::{collections::HashMap, fmt}; -use fluke_buffet::{Piece, Roll}; +use fluke_buffet::Piece; use crate::Response; @@ -231,13 +231,7 @@ pub(crate) enum HeadersOrTrailers { pub(crate) enum H2ConnEvent { ServerEvent(H2Event), - AcknowledgeSettings { - new_max_header_table_size: u32, - }, - RstStream { - stream_id: StreamId, - error_code: KnownErrorCode, - }, + AcknowledgeSettings { new_max_header_table_size: u32 }, } #[derive(Debug)] From 8c4c8b94e354c756bfc81185e5e780afb3dc8339 Mon Sep 17 00:00:00 2001 From: Amos Wenger Date: Tue, 12 Mar 2024 14:59:51 +0100 Subject: [PATCH 07/17] Move settings acknowledgement inside read --- crates/fluke/src/h2/read.rs | 21 ++++++++----------- crates/fluke/src/h2/types.rs | 1 - crates/fluke/src/h2/write.rs | 40 +----------------------------------- 3 files changed, 9 insertions(+), 53 deletions(-) diff --git a/crates/fluke/src/h2/read.rs b/crates/fluke/src/h2/read.rs index c348cb93..708ed2e9 100644 --- a/crates/fluke/src/h2/read.rs +++ b/crates/fluke/src/h2/read.rs @@ -491,22 +491,17 @@ impl H2ReadContext { let (_, settings) = Settings::parse(payload) .finish() .map_err(|err| eyre::eyre!("parsing error: {err:?}"))?; - let new_max_header_table_size = settings.header_table_size; + self.hpack_enc + .set_max_table_size(settings.header_table_size as usize); + debug!(?settings, "Received settings"); self.state.peer_settings = settings; - if self - .ev_tx - .send(H2ConnEvent::AcknowledgeSettings { - new_max_header_table_size, - }) - .await - .is_err() - { - return Err( - eyre::eyre!("could not send H2 acknowledge settings event").into() - ); - } + let frame = Frame::new( + FrameType::Settings(SettingsFlags::Ack.into()), + StreamId::CONNECTION, + ); + self.write_frame(frame, Roll::empty()).await?; } } FrameType::PushPromise => { diff --git a/crates/fluke/src/h2/types.rs b/crates/fluke/src/h2/types.rs index 82f303b9..dee6db1c 100644 --- a/crates/fluke/src/h2/types.rs +++ b/crates/fluke/src/h2/types.rs @@ -231,7 +231,6 @@ pub(crate) enum HeadersOrTrailers { pub(crate) enum H2ConnEvent { ServerEvent(H2Event), - AcknowledgeSettings { new_max_header_table_size: u32 }, } #[derive(Debug)] diff --git a/crates/fluke/src/h2/write.rs b/crates/fluke/src/h2/write.rs index d9e16c0f..8ab89e15 100644 --- a/crates/fluke/src/h2/write.rs +++ b/crates/fluke/src/h2/write.rs @@ -4,7 +4,7 @@ use tokio::sync::mpsc; use tracing::{debug, trace}; use crate::h2::{ - parse::{DataFlags, Frame, FrameType, HeadersFlags, PingFlags, SettingsFlags, StreamId}, + parse::{DataFlags, Frame, FrameType, HeadersFlags, SettingsFlags, StreamId}, types::{H2ConnEvent, H2EventPayload}, }; use fluke_buffet::{PieceList, RollMut}; @@ -21,21 +21,6 @@ pub(crate) async fn h2_write_loop( while let Some(ev) = ev_rx.recv().await { trace!("h2_write_loop: received H2 event"); match ev { - H2ConnEvent::AcknowledgeSettings { - new_max_header_table_size, - } => { - debug!("Acknowledging new settings"); - hpack_enc.set_max_table_size(new_max_header_table_size.try_into().unwrap()); - - let frame = Frame::new( - FrameType::Settings(SettingsFlags::Ack.into()), - StreamId::CONNECTION, - ); - transport_w - .write_all(frame.into_roll(&mut out_scratch)?) - .await - .wrap_err("writing acknowledge settings")?; - } H2ConnEvent::ServerEvent(ev) => { debug!(?ev, "Writing"); @@ -96,29 +81,6 @@ pub(crate) async fn h2_write_loop( } } } - H2ConnEvent::RstStream { - stream_id, - error_code, - } => { - debug!(%stream_id, ?error_code, "Sending RstStream"); - let header = out_scratch.put_to_roll(4, |mut slice| { - use byteorder::{BigEndian, WriteBytesExt}; - slice.write_u32::(error_code.repr())?; - Ok(()) - })?; - - let frame = Frame::new(FrameType::RstStream, stream_id) - .with_len((header.len()).try_into().unwrap()); - - transport_w - .writev_all( - PieceList::default() - .with(frame.into_roll(&mut out_scratch)?) - .with(header), - ) - .await - .wrap_err("writing rststream")?; - } } } From cc9af38d8a516a5f3fadad8eec18ff157c5da2c2 Mon Sep 17 00:00:00 2001 From: Amos Wenger Date: Tue, 12 Mar 2024 15:00:29 +0100 Subject: [PATCH 08/17] mhmh --- crates/fluke/src/h2/read.rs | 9 +++------ 1 file changed, 3 insertions(+), 6 deletions(-) diff --git a/crates/fluke/src/h2/read.rs b/crates/fluke/src/h2/read.rs index 708ed2e9..98b62fc9 100644 --- a/crates/fluke/src/h2/read.rs +++ b/crates/fluke/src/h2/read.rs @@ -275,10 +275,7 @@ impl H2ReadContext { }, maybe_frame = rx.recv() => { if let Some((frame, payload)) = maybe_frame { - let res = self.process_frame(frame, payload, &mut rx).await; - if let Err(e) = res { - return Err(e); - } + self.process_frame(frame, payload, &mut rx).await?; } else { debug!("h2 process task: peer hung up"); break; @@ -399,7 +396,7 @@ impl H2ReadContext { if num_streams_if_accept > max_concurrent_streams as _ { // reset the stream, indicating we refused it self.rst(frame.stream_id, H2StreamError::RefusedStream) - .await; + .await?; // but we still need to skip over any continuation frames mode = ReadHeadersMode::Skip; @@ -608,7 +605,7 @@ impl H2ReadContext { let frame = Frame::new(FrameType::RstStream, stream_id) .with_len((payload.len()).try_into().unwrap()); - self.write_frame(frame, payload); + self.write_frame(frame, payload).await?; Ok(()) } From 004914e6b00f573f3c7d4716db7727e479bae665 Mon Sep 17 00:00:00 2001 From: Amos Wenger Date: Tue, 12 Mar 2024 15:01:57 +0100 Subject: [PATCH 09/17] Retire H2ConnEvent --- crates/fluke/src/h2/encode.rs | 10 ++-- crates/fluke/src/h2/read.rs | 8 +-- crates/fluke/src/h2/types.rs | 4 -- crates/fluke/src/h2/write.rs | 108 ++++++++++++++++------------------ 4 files changed, 61 insertions(+), 69 deletions(-) diff --git a/crates/fluke/src/h2/encode.rs b/crates/fluke/src/h2/encode.rs index f08332e3..b5671c8b 100644 --- a/crates/fluke/src/h2/encode.rs +++ b/crates/fluke/src/h2/encode.rs @@ -4,7 +4,7 @@ use tracing::warn; use super::{ parse::StreamId, - types::{H2ConnEvent, H2Event, H2EventPayload}, + types::{H2Event, H2EventPayload}, }; use crate::{h1::body::BodyWriteMode, Encoder, Response}; @@ -16,16 +16,16 @@ pub(crate) enum EncoderState { pub struct H2Encoder { pub(crate) stream_id: StreamId, - pub(crate) tx: mpsc::Sender, + pub(crate) tx: mpsc::Sender, pub(crate) state: EncoderState, } impl H2Encoder { - fn event(&self, payload: H2EventPayload) -> H2ConnEvent { - H2ConnEvent::ServerEvent(H2Event { + fn event(&self, payload: H2EventPayload) -> H2Event { + H2Event { payload, stream_id: self.stream_id, - }) + } } async fn send(&self, payload: H2EventPayload) -> eyre::Result<()> { diff --git a/crates/fluke/src/h2/read.rs b/crates/fluke/src/h2/read.rs index 98b62fc9..28f2a370 100644 --- a/crates/fluke/src/h2/read.rs +++ b/crates/fluke/src/h2/read.rs @@ -35,7 +35,7 @@ use super::{ parse::{ ContinuationFlags, DataFlags, HeadersFlags, PingFlags, Settings, SettingsFlags, StreamId, }, - types::{ConnState, H2ConnEvent, H2StreamError, HeadersOrTrailers, StreamState}, + types::{ConnState, H2Event, H2StreamError, HeadersOrTrailers, StreamState}, }; /// Reads and processes h2 frames from the client. @@ -51,8 +51,8 @@ pub(crate) struct H2ReadContext { transport_w: W, - ev_tx: mpsc::Sender, - ev_rx: mpsc::Receiver, + ev_tx: mpsc::Sender, + ev_rx: mpsc::Receiver, } impl H2ReadContext { @@ -68,7 +68,7 @@ impl H2ReadContext { let hpack_enc = fluke_hpack::Encoder::new(); - let (ev_tx, ev_rx) = tokio::sync::mpsc::channel::(32); + let (ev_tx, ev_rx) = tokio::sync::mpsc::channel::(32); Self { driver, diff --git a/crates/fluke/src/h2/types.rs b/crates/fluke/src/h2/types.rs index dee6db1c..952e5962 100644 --- a/crates/fluke/src/h2/types.rs +++ b/crates/fluke/src/h2/types.rs @@ -229,10 +229,6 @@ pub(crate) enum HeadersOrTrailers { Trailers, } -pub(crate) enum H2ConnEvent { - ServerEvent(H2Event), -} - #[derive(Debug)] pub(crate) struct H2Event { pub(crate) stream_id: StreamId, diff --git a/crates/fluke/src/h2/write.rs b/crates/fluke/src/h2/write.rs index 8ab89e15..67ede04d 100644 --- a/crates/fluke/src/h2/write.rs +++ b/crates/fluke/src/h2/write.rs @@ -4,15 +4,15 @@ use tokio::sync::mpsc; use tracing::{debug, trace}; use crate::h2::{ - parse::{DataFlags, Frame, FrameType, HeadersFlags, SettingsFlags, StreamId}, - types::{H2ConnEvent, H2EventPayload}, + parse::{DataFlags, Frame, FrameType, HeadersFlags}, + types::{H2Event, H2EventPayload}, }; use fluke_buffet::{PieceList, RollMut}; use fluke_maybe_uring::io::WriteOwned; /// Write H2 frames to the transport, from a channel pub(crate) async fn h2_write_loop( - mut ev_rx: mpsc::Receiver, + mut ev_rx: mpsc::Receiver, mut transport_w: impl WriteOwned, mut out_scratch: RollMut, ) -> eyre::Result<()> { @@ -20,66 +20,62 @@ pub(crate) async fn h2_write_loop( while let Some(ev) = ev_rx.recv().await { trace!("h2_write_loop: received H2 event"); - match ev { - H2ConnEvent::ServerEvent(ev) => { - debug!(?ev, "Writing"); + debug!(?ev, "Writing"); - match ev.payload { - H2EventPayload::Headers(res) => { - let flags = HeadersFlags::EndHeaders; - let mut frame = Frame::new(FrameType::Headers(flags.into()), ev.stream_id); + match ev.payload { + H2EventPayload::Headers(res) => { + let flags = HeadersFlags::EndHeaders; + let mut frame = Frame::new(FrameType::Headers(flags.into()), ev.stream_id); - // TODO: don't allocate so much for headers. all `encode_into` - // wants is an `IntoIter`, we can definitely have a custom iterator - // that operates on all this instead of using a `Vec`. + // TODO: don't allocate so much for headers. all `encode_into` + // wants is an `IntoIter`, we can definitely have a custom iterator + // that operates on all this instead of using a `Vec`. - // TODO: limit header size - let mut headers: Vec<(&[u8], &[u8])> = vec![]; - headers.push((b":status", res.status.as_str().as_bytes())); - for (name, value) in res.headers.iter() { - if name == http::header::TRANSFER_ENCODING { - // do not set transfer-encoding: chunked when doing HTTP/2 - continue; - } - headers.push((name.as_str().as_bytes(), value)); - } + // TODO: limit header size + let mut headers: Vec<(&[u8], &[u8])> = vec![]; + headers.push((b":status", res.status.as_str().as_bytes())); + for (name, value) in res.headers.iter() { + if name == http::header::TRANSFER_ENCODING { + // do not set transfer-encoding: chunked when doing HTTP/2 + continue; + } + headers.push((name.as_str().as_bytes(), value)); + } - assert_eq!(out_scratch.len(), 0); - hpack_enc.encode_into(headers, &mut out_scratch)?; - let fragment_block = out_scratch.take_all(); + assert_eq!(out_scratch.len(), 0); + hpack_enc.encode_into(headers, &mut out_scratch)?; + let fragment_block = out_scratch.take_all(); - frame.len = fragment_block.len() as u32; - let frame_roll = frame.into_roll(&mut out_scratch)?; + frame.len = fragment_block.len() as u32; + let frame_roll = frame.into_roll(&mut out_scratch)?; - transport_w - .writev_all(PieceList::default().with(frame_roll).with(fragment_block)) - .await - .wrap_err("writing headers")?; - } - H2EventPayload::BodyChunk(chunk) => { - let flags = BitFlags::::default(); - let frame = Frame::new(FrameType::Data(flags), ev.stream_id) - .with_len(chunk.len().try_into().unwrap()); - let frame_roll = frame.into_roll(&mut out_scratch)?; - transport_w - .writev_all(PieceList::default().with(frame_roll).with(chunk)) - .await - .wrap_err("writing bodychunk")?; - } - H2EventPayload::BodyEnd => { - // FIXME: this should transition the stream to `Closed` - // state (or at the very least `HalfClosedLocal`). - // Either way, whoever owns the stream state should know - // about it, cf. https://github.com/hapsoc/fluke/issues/123 + transport_w + .writev_all(PieceList::default().with(frame_roll).with(fragment_block)) + .await + .wrap_err("writing headers")?; + } + H2EventPayload::BodyChunk(chunk) => { + let flags = BitFlags::::default(); + let frame = Frame::new(FrameType::Data(flags), ev.stream_id) + .with_len(chunk.len().try_into().unwrap()); + let frame_roll = frame.into_roll(&mut out_scratch)?; + transport_w + .writev_all(PieceList::default().with(frame_roll).with(chunk)) + .await + .wrap_err("writing bodychunk")?; + } + H2EventPayload::BodyEnd => { + // FIXME: this should transition the stream to `Closed` + // state (or at the very least `HalfClosedLocal`). + // Either way, whoever owns the stream state should know + // about it, cf. https://github.com/hapsoc/fluke/issues/123 - let flags = DataFlags::EndStream; - let frame = Frame::new(FrameType::Data(flags.into()), ev.stream_id); - transport_w - .write_all(frame.into_roll(&mut out_scratch)?) - .await - .wrap_err("writing BodyEnd")?; - } - } + let flags = DataFlags::EndStream; + let frame = Frame::new(FrameType::Data(flags.into()), ev.stream_id); + transport_w + .write_all(frame.into_roll(&mut out_scratch)?) + .await + .wrap_err("writing BodyEnd")?; } } } From ed2aa99f76b09df80a4ccb4fd7d7c9cf4966cba8 Mon Sep 17 00:00:00 2001 From: Amos Wenger Date: Tue, 12 Mar 2024 15:14:10 +0100 Subject: [PATCH 10/17] Move all writing out of write --- crates/fluke/src/h2/read.rs | 99 ++++++++++++++++++++++++++++++++---- crates/fluke/src/h2/types.rs | 9 +++- 2 files changed, 96 insertions(+), 12 deletions(-) diff --git a/crates/fluke/src/h2/read.rs b/crates/fluke/src/h2/read.rs index 28f2a370..f86b2aef 100644 --- a/crates/fluke/src/h2/read.rs +++ b/crates/fluke/src/h2/read.rs @@ -8,7 +8,7 @@ use std::{ use byteorder::{BigEndian, WriteBytesExt}; use enumflags2::BitFlags; use eyre::Context; -use fluke_buffet::{Piece, PieceStr, Roll, RollMut}; +use fluke_buffet::{Piece, PieceList, PieceStr, Roll, RollMut}; use fluke_maybe_uring::io::{ReadOwned, WriteOwned}; use http::{ header, @@ -35,7 +35,7 @@ use super::{ parse::{ ContinuationFlags, DataFlags, HeadersFlags, PingFlags, Settings, SettingsFlags, StreamId, }, - types::{ConnState, H2Event, H2StreamError, HeadersOrTrailers, StreamState}, + types::{ConnState, H2Event, H2EventPayload, H2StreamError, HeadersOrTrailers, StreamState}, }; /// Reads and processes h2 frames from the client. @@ -270,8 +270,11 @@ impl H2ReadContext { ) -> Result<(), H2ConnectionError> { loop { tokio::select! { - _ = self.ev_rx.recv() => { - todo!("handle conn events") + ev = self.ev_rx.recv() => { + match ev { + Some(ev) => self.handle_event(ev).await?, + None => unreachable!("the context owns a copy of the sender, and this method has &mut self, so the sender can't be dropped while this method is running"), + } }, maybe_frame = rx.recv() => { if let Some((frame, payload)) = maybe_frame { @@ -287,7 +290,74 @@ impl H2ReadContext { Ok(()) } - async fn write_frame(&mut self, frame: Frame, payload: Roll) -> Result<(), H2ConnectionError> { + async fn handle_event(&mut self, ev: H2Event) -> Result<(), H2ConnectionError> { + match ev.payload { + H2EventPayload::Headers(res) => { + let flags = HeadersFlags::EndHeaders; + let frame = Frame::new(FrameType::Headers(flags.into()), ev.stream_id); + + // TODO: don't allocate so much for headers. all `encode_into` + // wants is an `IntoIter`, we can definitely have a custom iterator + // that operates on all this instead of using a `Vec`. + + // TODO: limit header size + let mut headers: Vec<(&[u8], &[u8])> = vec![]; + headers.push((b":status", res.status.as_str().as_bytes())); + for (name, value) in res.headers.iter() { + if name == http::header::TRANSFER_ENCODING { + // do not set transfer-encoding: chunked when doing HTTP/2 + continue; + } + headers.push((name.as_str().as_bytes(), value)); + } + + assert_eq!(self.out_scratch.len(), 0); + self.hpack_enc + .encode_into(headers, &mut self.out_scratch) + .map_err(H2ConnectionError::WriteError)?; + let payload = self.out_scratch.take_all(); + + self.write_frame(frame, payload).await?; + } + H2EventPayload::BodyChunk(chunk) => { + let flags = BitFlags::::default(); + let frame = Frame::new(FrameType::Data(flags), ev.stream_id); + + self.write_frame(frame, chunk).await?; + } + H2EventPayload::BodyEnd => { + // FIXME: this should transition the stream to `Closed` + // state (or at the very least `HalfClosedLocal`). + // Either way, whoever owns the stream state should know + // about it, cf. https://github.com/hapsoc/fluke/issues/123 + + let flags = DataFlags::EndStream; + let frame = Frame::new(FrameType::Data(flags.into()), ev.stream_id); + self.write_frame(frame, Roll::empty()).await?; + } + } + + Ok(()) + } + + async fn write_frame( + &mut self, + frame: Frame, + payload: impl Into, + ) -> Result<(), H2ConnectionError> { + let payload = payload.into(); + + match &frame.frame_type { + FrameType::Data(headers) => { + if headers.contains(DataFlags::EndStream) { + // if the stream is open, this transitions to HalfClosedLocal + } + } + _ => { + // muffin. + } + } + let frame_roll = frame.into_roll(&mut self.out_scratch)?; if payload.is_empty() { @@ -297,7 +367,7 @@ impl H2ReadContext { .map_err(H2ConnectionError::WriteError)?; } else { self.transport_w - .writev_all(&[frame_roll, payload]) + .writev_all(PieceList::default().with(frame_roll).with(payload)) .await .map_err(H2ConnectionError::WriteError)?; } @@ -320,8 +390,8 @@ impl H2ReadContext { )?; match ss { - StreamState::Open(tx) => { - if tx + StreamState::Open(body_tx) | StreamState::HalfClosedLocal(body_tx) => { + if body_tx .send(Ok(PieceOrTrailers::Piece(payload.into()))) .await .is_err() @@ -330,7 +400,13 @@ impl H2ReadContext { } if flags.contains(DataFlags::EndStream) { - *ss = StreamState::HalfClosedRemote; + // if we're HalfClosedLocal, this transitions to Closed + // otherwise, it transitions to HalfClosedRemote + if matches!(ss, StreamState::Open(_)) { + *ss = StreamState::HalfClosedRemote; + } else { + self.state.streams.remove(&frame.stream_id); + } } } StreamState::HalfClosedRemote => { @@ -406,7 +482,7 @@ impl H2ReadContext { } } } - Some(StreamState::Open(_)) => { + Some(StreamState::Open(_) | StreamState::HalfClosedLocal(_)) => { headers_or_trailers = HeadersOrTrailers::Trailers; debug!("Receiving trailers for stream {}", frame.stream_id); @@ -453,6 +529,7 @@ impl H2ReadContext { }); } } + // note: this always unconditionally transitions the stream to closed FrameType::RstStream => match self.state.streams.remove(&frame.stream_id) { None => { return Err(H2ConnectionError::RstStreamForUnknownStream { @@ -460,7 +537,7 @@ impl H2ReadContext { }) } Some(ss) => match ss { - StreamState::Open(body_tx) => { + StreamState::Open(body_tx) | StreamState::HalfClosedLocal(body_tx) => { _ = body_tx .send(Err(H2StreamError::ReceivedRstStream.into())) .await; diff --git a/crates/fluke/src/h2/types.rs b/crates/fluke/src/h2/types.rs index 952e5962..d82d7450 100644 --- a/crates/fluke/src/h2/types.rs +++ b/crates/fluke/src/h2/types.rs @@ -68,9 +68,16 @@ impl Default for ConnState { // PP: PUSH_PROMISE frame (with implied CONTINUATION frames); state // transitions are for the promised stream pub(crate) enum StreamState { + // we have received full HEADERS Open(H2BodySender), + + // the peer has sent END_STREAM/RST_STREAM HalfClosedRemote, - // note: the "Closed" state is indicated by not having an entry in the map + + // we have sent END_STREAM/RST_STREAM + HalfClosedLocal(H2BodySender), + // + // Note: the "Closed" state is indicated by not having an entry in the map } #[derive(Debug, thiserror::Error)] From fbad3eab0d8a065fbf216148210f19c00a01b3c3 Mon Sep 17 00:00:00 2001 From: Amos Wenger Date: Tue, 12 Mar 2024 15:21:54 +0100 Subject: [PATCH 11/17] Remove write module altogether --- crates/fluke/src/h2/mod.rs | 1 - crates/fluke/src/h2/read.rs | 25 ++++++++++- crates/fluke/src/h2/write.rs | 84 ------------------------------------ 3 files changed, 24 insertions(+), 86 deletions(-) delete mode 100644 crates/fluke/src/h2/write.rs diff --git a/crates/fluke/src/h2/mod.rs b/crates/fluke/src/h2/mod.rs index ea99fa1f..0761ec56 100644 --- a/crates/fluke/src/h2/mod.rs +++ b/crates/fluke/src/h2/mod.rs @@ -10,4 +10,3 @@ mod body; mod encode; mod read; mod types; -mod write; diff --git a/crates/fluke/src/h2/read.rs b/crates/fluke/src/h2/read.rs index f86b2aef..9cac6d49 100644 --- a/crates/fluke/src/h2/read.rs +++ b/crates/fluke/src/h2/read.rs @@ -49,6 +49,8 @@ pub(crate) struct H2ReadContext { /// Whether we've received a GOAWAY frame. pub goaway_recv: bool, + /// TODO: encapsulate into a framer, don't + /// allow direct access from context methods transport_w: W, ev_tx: mpsc::Sender, @@ -350,7 +352,28 @@ impl H2ReadContext { match &frame.frame_type { FrameType::Data(headers) => { if headers.contains(DataFlags::EndStream) { - // if the stream is open, this transitions to HalfClosedLocal + // if the stream is open, this transitions to HalfClosedLocal. + if let Some(ss) = self.state.streams.get_mut(&frame.stream_id) { + match ss { + StreamState::Open(_) => { + // transition through StreamState::HalfClosedRemote + // so we don't have to remove/re-insert. + let mut entry = StreamState::HalfClosedRemote; + std::mem::swap(&mut entry, ss); + + let body_tx = match entry { + StreamState::Open(body_tx) => body_tx, + _ => unreachable!(), + }; + + *ss = StreamState::HalfClosedLocal(body_tx); + } + _ => { + // transition to closed + self.state.streams.remove(&frame.stream_id); + } + } + } } } _ => { diff --git a/crates/fluke/src/h2/write.rs b/crates/fluke/src/h2/write.rs deleted file mode 100644 index 67ede04d..00000000 --- a/crates/fluke/src/h2/write.rs +++ /dev/null @@ -1,84 +0,0 @@ -use enumflags2::BitFlags; -use eyre::Context; -use tokio::sync::mpsc; -use tracing::{debug, trace}; - -use crate::h2::{ - parse::{DataFlags, Frame, FrameType, HeadersFlags}, - types::{H2Event, H2EventPayload}, -}; -use fluke_buffet::{PieceList, RollMut}; -use fluke_maybe_uring::io::WriteOwned; - -/// Write H2 frames to the transport, from a channel -pub(crate) async fn h2_write_loop( - mut ev_rx: mpsc::Receiver, - mut transport_w: impl WriteOwned, - mut out_scratch: RollMut, -) -> eyre::Result<()> { - let mut hpack_enc = fluke_hpack::Encoder::new(); - - while let Some(ev) = ev_rx.recv().await { - trace!("h2_write_loop: received H2 event"); - debug!(?ev, "Writing"); - - match ev.payload { - H2EventPayload::Headers(res) => { - let flags = HeadersFlags::EndHeaders; - let mut frame = Frame::new(FrameType::Headers(flags.into()), ev.stream_id); - - // TODO: don't allocate so much for headers. all `encode_into` - // wants is an `IntoIter`, we can definitely have a custom iterator - // that operates on all this instead of using a `Vec`. - - // TODO: limit header size - let mut headers: Vec<(&[u8], &[u8])> = vec![]; - headers.push((b":status", res.status.as_str().as_bytes())); - for (name, value) in res.headers.iter() { - if name == http::header::TRANSFER_ENCODING { - // do not set transfer-encoding: chunked when doing HTTP/2 - continue; - } - headers.push((name.as_str().as_bytes(), value)); - } - - assert_eq!(out_scratch.len(), 0); - hpack_enc.encode_into(headers, &mut out_scratch)?; - let fragment_block = out_scratch.take_all(); - - frame.len = fragment_block.len() as u32; - let frame_roll = frame.into_roll(&mut out_scratch)?; - - transport_w - .writev_all(PieceList::default().with(frame_roll).with(fragment_block)) - .await - .wrap_err("writing headers")?; - } - H2EventPayload::BodyChunk(chunk) => { - let flags = BitFlags::::default(); - let frame = Frame::new(FrameType::Data(flags), ev.stream_id) - .with_len(chunk.len().try_into().unwrap()); - let frame_roll = frame.into_roll(&mut out_scratch)?; - transport_w - .writev_all(PieceList::default().with(frame_roll).with(chunk)) - .await - .wrap_err("writing bodychunk")?; - } - H2EventPayload::BodyEnd => { - // FIXME: this should transition the stream to `Closed` - // state (or at the very least `HalfClosedLocal`). - // Either way, whoever owns the stream state should know - // about it, cf. https://github.com/hapsoc/fluke/issues/123 - - let flags = DataFlags::EndStream; - let frame = Frame::new(FrameType::Data(flags.into()), ev.stream_id); - transport_w - .write_all(frame.into_roll(&mut out_scratch)?) - .await - .wrap_err("writing BodyEnd")?; - } - } - } - - Ok(()) -} From 819eb1242d00e41f55fddf8452f9477e2c5b1635 Mon Sep 17 00:00:00 2001 From: Amos Wenger Date: Tue, 12 Mar 2024 15:24:24 +0100 Subject: [PATCH 12/17] Remove read module, move it back into server.rs --- crates/fluke/src/h2/mod.rs | 1 - crates/fluke/src/h2/read.rs | 992 --------------------------------- crates/fluke/src/h2/server.rs | 994 +++++++++++++++++++++++++++++++++- 3 files changed, 985 insertions(+), 1002 deletions(-) delete mode 100644 crates/fluke/src/h2/read.rs diff --git a/crates/fluke/src/h2/mod.rs b/crates/fluke/src/h2/mod.rs index 0761ec56..ef609f0c 100644 --- a/crates/fluke/src/h2/mod.rs +++ b/crates/fluke/src/h2/mod.rs @@ -8,5 +8,4 @@ pub(crate) mod parse; mod body; mod encode; -mod read; mod types; diff --git a/crates/fluke/src/h2/read.rs b/crates/fluke/src/h2/read.rs deleted file mode 100644 index 9cac6d49..00000000 --- a/crates/fluke/src/h2/read.rs +++ /dev/null @@ -1,992 +0,0 @@ -use std::{ - borrow::Cow, - io::Write, - rc::Rc, - sync::atomic::{AtomicU32, Ordering}, -}; - -use byteorder::{BigEndian, WriteBytesExt}; -use enumflags2::BitFlags; -use eyre::Context; -use fluke_buffet::{Piece, PieceList, PieceStr, Roll, RollMut}; -use fluke_maybe_uring::io::{ReadOwned, WriteOwned}; -use http::{ - header, - uri::{Authority, PathAndQuery, Scheme}, - HeaderName, Version, -}; -use nom::Finish; -use smallvec::{smallvec, SmallVec}; -use tokio::sync::mpsc; -use tracing::{debug, trace, warn}; - -use crate::{ - h2::{ - parse::{parse_reserved_and_u31, Frame, FrameType, PrioritySpec}, - types::H2ConnectionError, - }, - util::read_and_parse, - ExpectResponseHeaders, Headers, Method, Request, Responder, ServerDriver, -}; - -use super::{ - body::{H2Body, H2BodyItem, PieceOrTrailers}, - encode::{EncoderState, H2Encoder}, - parse::{ - ContinuationFlags, DataFlags, HeadersFlags, PingFlags, Settings, SettingsFlags, StreamId, - }, - types::{ConnState, H2Event, H2EventPayload, H2StreamError, HeadersOrTrailers, StreamState}, -}; - -/// Reads and processes h2 frames from the client. -pub(crate) struct H2ReadContext { - driver: Rc, - state: ConnState, - hpack_dec: fluke_hpack::Decoder<'static>, - hpack_enc: fluke_hpack::Encoder<'static>, - out_scratch: RollMut, - - /// Whether we've received a GOAWAY frame. - pub goaway_recv: bool, - - /// TODO: encapsulate into a framer, don't - /// allow direct access from context methods - transport_w: W, - - ev_tx: mpsc::Sender, - ev_rx: mpsc::Receiver, -} - -impl H2ReadContext { - pub(crate) fn new( - driver: Rc, - state: ConnState, - transport_w: W, - out_scratch: RollMut, - ) -> Self { - let mut hpack_dec = fluke_hpack::Decoder::new(); - hpack_dec - .set_max_allowed_table_size(Settings::default().header_table_size.try_into().unwrap()); - - let hpack_enc = fluke_hpack::Encoder::new(); - - let (ev_tx, ev_rx) = tokio::sync::mpsc::channel::(32); - - Self { - driver, - ev_tx, - ev_rx, - state, - hpack_dec, - hpack_enc, - out_scratch, - goaway_recv: false, - transport_w, - } - } - - /// Reads and process h2 frames from the client. - pub(crate) async fn work( - mut self, - client_buf: RollMut, - transport_r: impl ReadOwned, - ) -> eyre::Result<()> { - let mut goaway_err: Option = None; - - { - // read frames and send them into an mpsc buffer of size 1 - let (tx, rx) = mpsc::channel::<(Frame, Roll)>(1); - - // store max frame size setting as an atomic so we can share it across tasks - // FIXME: the process_task should update this - let max_frame_size = Rc::new(AtomicU32::new(self.state.self_settings.max_frame_size)); - - let mut deframe_task = std::pin::pin!(Self::deframe_loop( - client_buf, - transport_r, - tx, - max_frame_size - )); - let mut process_task = std::pin::pin!(self.process_loop(rx)); - - tokio::select! { - res = &mut deframe_task => { - debug!(?res, "h2 deframe task finished"); - - if let Err(H2ConnectionError::ReadError(e)) = res { - let mut should_ignore_err = false; - - // if this is a connection reset and we've sent a goaway, ignore it - if let Some(io_error) = e.root_cause().downcast_ref::() { - if io_error.kind() == std::io::ErrorKind::ConnectionReset { - should_ignore_err = true; - } - } - - if !should_ignore_err { - return Err(e.wrap_err("h2 io")); - } - } - - if let Err(e) = (&mut process_task).await { - debug!("h2 process task finished with error: {e}"); - return Err(e).wrap_err("h2 process"); - } - } - res = &mut process_task => { - debug!(?res, "h2 process task finished"); - - if let Err(err) = res { - goaway_err = Some(err); - } - } - } - } - - if let Some(err) = goaway_err { - let error_code = err.as_known_error_code(); - debug!("Connection error: {err} ({err:?}) (code {error_code:?})"); - - // TODO: don't heap-allocate here - let additional_debug_data = format!("{err}").into_bytes(); - - // TODO: figure out graceful shutdown: this would involve sending a goaway - // before this point, and processing all the connections we've accepted - debug!(last_stream_id = %self.state.last_stream_id, ?error_code, "Sending GoAway"); - let payload = - self.out_scratch - .put_to_roll(8 + additional_debug_data.len(), |mut slice| { - slice.write_u32::(self.state.last_stream_id.0)?; - slice.write_u32::(error_code.repr())?; - slice.write_all(additional_debug_data.as_slice())?; - - Ok(()) - })?; - - let frame = Frame::new(FrameType::GoAway, StreamId::CONNECTION); - self.write_frame(frame, payload).await?; - } - - Ok(()) - } - - async fn deframe_loop( - mut client_buf: RollMut, - mut transport_r: impl ReadOwned, - tx: mpsc::Sender<(Frame, Roll)>, - max_frame_size: Rc, - ) -> Result<(), H2ConnectionError> { - 'read_frames: loop { - const MAX_FRAME_HEADER_SIZE: usize = 128; - let frame; - let frame_res = read_and_parse( - Frame::parse, - &mut transport_r, - client_buf, - MAX_FRAME_HEADER_SIZE, - ) - .await; - - let maybe_frame = match frame_res { - Ok(inner) => inner, - Err(e) => return Err(H2ConnectionError::ReadError(e)), - }; - (client_buf, frame) = match maybe_frame { - Some((client_buf, frame)) => (client_buf, frame), - None => { - debug!("Peer went away before sending a frame"); - break 'read_frames; - } - }; - - debug!(?frame, "Received"); - - let max_frame_size = max_frame_size.load(Ordering::Relaxed); - if frame.len > max_frame_size { - return Err(H2ConnectionError::FrameTooLarge { - frame_type: frame.frame_type, - frame_size: frame.len, - max_frame_size, - }); - } - - let mut payload; - (client_buf, payload) = match read_and_parse( - nom::bytes::streaming::take(frame.len as usize), - &mut transport_r, - client_buf, - frame.len as usize, - ) - .await? - { - Some((client_buf, payload)) => (client_buf, payload), - None => { - return Err(H2ConnectionError::IncompleteFrame { - frame_type: frame.frame_type, - frame_size: frame.len, - }) - } - }; - - let has_padding = match frame.frame_type { - FrameType::Data(flags) => flags.contains(DataFlags::Padded), - FrameType::Headers(flags) => flags.contains(HeadersFlags::Padded), - _ => false, - }; - - if has_padding { - if payload.is_empty() { - return Err(H2ConnectionError::PaddedFrameEmpty { - frame_type: frame.frame_type, - }); - } - - let padding_length_roll; - (padding_length_roll, payload) = payload.split_at(1); - let padding_length = padding_length_roll[0] as usize; - if payload.len() < padding_length { - return Err(H2ConnectionError::PaddedFrameTooShort { - frame_type: frame.frame_type, - padding_length, - frame_size: frame.len, - }); - } - - // padding is on the end of the payload - let at = payload.len() - padding_length; - (payload, _) = payload.split_at(at); - } - - if tx.send((frame, payload)).await.is_err() { - debug!("h2 deframer: receiver dropped, closing connection"); - return Ok(()); - } - } - - Ok(()) - } - - async fn process_loop( - &mut self, - mut rx: mpsc::Receiver<(Frame, Roll)>, - ) -> Result<(), H2ConnectionError> { - loop { - tokio::select! { - ev = self.ev_rx.recv() => { - match ev { - Some(ev) => self.handle_event(ev).await?, - None => unreachable!("the context owns a copy of the sender, and this method has &mut self, so the sender can't be dropped while this method is running"), - } - }, - maybe_frame = rx.recv() => { - if let Some((frame, payload)) = maybe_frame { - self.process_frame(frame, payload, &mut rx).await?; - } else { - debug!("h2 process task: peer hung up"); - break; - } - } - } - } - - Ok(()) - } - - async fn handle_event(&mut self, ev: H2Event) -> Result<(), H2ConnectionError> { - match ev.payload { - H2EventPayload::Headers(res) => { - let flags = HeadersFlags::EndHeaders; - let frame = Frame::new(FrameType::Headers(flags.into()), ev.stream_id); - - // TODO: don't allocate so much for headers. all `encode_into` - // wants is an `IntoIter`, we can definitely have a custom iterator - // that operates on all this instead of using a `Vec`. - - // TODO: limit header size - let mut headers: Vec<(&[u8], &[u8])> = vec![]; - headers.push((b":status", res.status.as_str().as_bytes())); - for (name, value) in res.headers.iter() { - if name == http::header::TRANSFER_ENCODING { - // do not set transfer-encoding: chunked when doing HTTP/2 - continue; - } - headers.push((name.as_str().as_bytes(), value)); - } - - assert_eq!(self.out_scratch.len(), 0); - self.hpack_enc - .encode_into(headers, &mut self.out_scratch) - .map_err(H2ConnectionError::WriteError)?; - let payload = self.out_scratch.take_all(); - - self.write_frame(frame, payload).await?; - } - H2EventPayload::BodyChunk(chunk) => { - let flags = BitFlags::::default(); - let frame = Frame::new(FrameType::Data(flags), ev.stream_id); - - self.write_frame(frame, chunk).await?; - } - H2EventPayload::BodyEnd => { - // FIXME: this should transition the stream to `Closed` - // state (or at the very least `HalfClosedLocal`). - // Either way, whoever owns the stream state should know - // about it, cf. https://github.com/hapsoc/fluke/issues/123 - - let flags = DataFlags::EndStream; - let frame = Frame::new(FrameType::Data(flags.into()), ev.stream_id); - self.write_frame(frame, Roll::empty()).await?; - } - } - - Ok(()) - } - - async fn write_frame( - &mut self, - frame: Frame, - payload: impl Into, - ) -> Result<(), H2ConnectionError> { - let payload = payload.into(); - - match &frame.frame_type { - FrameType::Data(headers) => { - if headers.contains(DataFlags::EndStream) { - // if the stream is open, this transitions to HalfClosedLocal. - if let Some(ss) = self.state.streams.get_mut(&frame.stream_id) { - match ss { - StreamState::Open(_) => { - // transition through StreamState::HalfClosedRemote - // so we don't have to remove/re-insert. - let mut entry = StreamState::HalfClosedRemote; - std::mem::swap(&mut entry, ss); - - let body_tx = match entry { - StreamState::Open(body_tx) => body_tx, - _ => unreachable!(), - }; - - *ss = StreamState::HalfClosedLocal(body_tx); - } - _ => { - // transition to closed - self.state.streams.remove(&frame.stream_id); - } - } - } - } - } - _ => { - // muffin. - } - } - - let frame_roll = frame.into_roll(&mut self.out_scratch)?; - - if payload.is_empty() { - self.transport_w - .write_all(frame_roll) - .await - .map_err(H2ConnectionError::WriteError)?; - } else { - self.transport_w - .writev_all(PieceList::default().with(frame_roll).with(payload)) - .await - .map_err(H2ConnectionError::WriteError)?; - } - - Ok(()) - } - - async fn process_frame( - &mut self, - frame: Frame, - mut payload: Roll, - rx: &mut mpsc::Receiver<(Frame, Roll)>, - ) -> Result<(), H2ConnectionError> { - match frame.frame_type { - FrameType::Data(flags) => { - let ss = self.state.streams.get_mut(&frame.stream_id).ok_or( - H2ConnectionError::StreamClosed { - stream_id: frame.stream_id, - }, - )?; - - match ss { - StreamState::Open(body_tx) | StreamState::HalfClosedLocal(body_tx) => { - if body_tx - .send(Ok(PieceOrTrailers::Piece(payload.into()))) - .await - .is_err() - { - warn!("TODO: The body is being ignored, we should reset the stream"); - } - - if flags.contains(DataFlags::EndStream) { - // if we're HalfClosedLocal, this transitions to Closed - // otherwise, it transitions to HalfClosedRemote - if matches!(ss, StreamState::Open(_)) { - *ss = StreamState::HalfClosedRemote; - } else { - self.state.streams.remove(&frame.stream_id); - } - } - } - StreamState::HalfClosedRemote => { - debug!( - stream_id = %frame.stream_id, - "Received data for closed stream" - ); - self.rst(frame.stream_id, H2StreamError::StreamClosed) - .await?; - } - } - } - FrameType::Headers(flags) => { - if flags.contains(HeadersFlags::Priority) { - let pri_spec; - (payload, pri_spec) = PrioritySpec::parse(payload) - .finish() - .map_err(|err| eyre::eyre!("parsing error: {err:?}"))?; - debug!(exclusive = %pri_spec.exclusive, stream_dependency = ?pri_spec.stream_dependency, weight = %pri_spec.weight, "received priority, exclusive"); - - if pri_spec.stream_dependency == frame.stream_id { - return Err(H2ConnectionError::HeadersInvalidPriority { - stream_id: frame.stream_id, - }); - } - } - - let headers_or_trailers; - let mode; - - match self.state.streams.get_mut(&frame.stream_id) { - None => { - headers_or_trailers = HeadersOrTrailers::Headers; - debug!( - stream_id = %frame.stream_id, - last_stream_id = %self.state.last_stream_id, - next_stream_count = %self.state.streams.len() + 1, - "Receiving headers", - ); - - if frame.stream_id.is_server_initiated() { - return Err(H2ConnectionError::ClientSidShouldBeOdd); - } - - if frame.stream_id <= self.state.last_stream_id { - debug!( - frame_stream_id = %frame.stream_id, - last_stream_id = %self.state.last_stream_id, - "Received headers for invalid stream ID" - ); - - // this stream may have existed, but it no longer does: - return Err(H2ConnectionError::StreamClosed { - stream_id: frame.stream_id, - }); - } else { - // TODO: if we're shutting down, ignore streams higher - // than the last one we accepted. - - let max_concurrent_streams = - self.state.self_settings.max_concurrent_streams; - let num_streams_if_accept = self.state.streams.len() + 1; - if num_streams_if_accept > max_concurrent_streams as _ { - // reset the stream, indicating we refused it - self.rst(frame.stream_id, H2StreamError::RefusedStream) - .await?; - - // but we still need to skip over any continuation frames - mode = ReadHeadersMode::Skip; - } else { - self.state.last_stream_id = frame.stream_id; - mode = ReadHeadersMode::Process; - } - } - } - Some(StreamState::Open(_) | StreamState::HalfClosedLocal(_)) => { - headers_or_trailers = HeadersOrTrailers::Trailers; - debug!("Receiving trailers for stream {}", frame.stream_id); - - if flags.contains(HeadersFlags::EndStream) { - // good, that's what we expect - mode = ReadHeadersMode::Process; - } else { - // ignore trailers, we're not accepting the stream - mode = ReadHeadersMode::Skip; - - self.rst(frame.stream_id, H2StreamError::TrailersNotEndStream) - .await?; - } - } - Some(StreamState::HalfClosedRemote) => { - return Err(H2ConnectionError::StreamClosed { - stream_id: frame.stream_id, - }); - } - } - - self.read_headers( - headers_or_trailers, - mode, - flags, - frame.stream_id, - payload, - rx, - ) - .await?; - } - FrameType::Priority => { - let pri_spec = match PrioritySpec::parse(payload) { - Ok((_rest, pri_spec)) => pri_spec, - Err(e) => { - todo!("handle connection error: invalid priority frame {e}") - } - }; - debug!(?pri_spec, "received priority frame"); - - if pri_spec.stream_dependency == frame.stream_id { - return Err(H2ConnectionError::HeadersInvalidPriority { - stream_id: frame.stream_id, - }); - } - } - // note: this always unconditionally transitions the stream to closed - FrameType::RstStream => match self.state.streams.remove(&frame.stream_id) { - None => { - return Err(H2ConnectionError::RstStreamForUnknownStream { - stream_id: frame.stream_id, - }) - } - Some(ss) => match ss { - StreamState::Open(body_tx) | StreamState::HalfClosedLocal(body_tx) => { - _ = body_tx - .send(Err(H2StreamError::ReceivedRstStream.into())) - .await; - } - StreamState::HalfClosedRemote => { - // good - } - }, - }, - FrameType::Settings(s) => { - if frame.stream_id != StreamId::CONNECTION { - return Err(H2ConnectionError::SettingsWithNonZeroStreamId { - stream_id: frame.stream_id, - }); - } - - if s.contains(SettingsFlags::Ack) { - debug!("Peer has acknowledged our settings, cool"); - if !payload.is_empty() { - return Err(H2ConnectionError::SettingsAckWithPayload { - len: payload.len() as _, - }); - } - } else { - let (_, settings) = Settings::parse(payload) - .finish() - .map_err(|err| eyre::eyre!("parsing error: {err:?}"))?; - self.hpack_enc - .set_max_table_size(settings.header_table_size as usize); - - debug!(?settings, "Received settings"); - self.state.peer_settings = settings; - - let frame = Frame::new( - FrameType::Settings(SettingsFlags::Ack.into()), - StreamId::CONNECTION, - ); - self.write_frame(frame, Roll::empty()).await?; - } - } - FrameType::PushPromise => { - return Err(H2ConnectionError::ClientSentPushPromise); - } - FrameType::Ping(flags) => { - if frame.stream_id != StreamId::CONNECTION { - return Err(H2ConnectionError::PingFrameWithNonZeroStreamId { - stream_id: frame.stream_id, - }); - } - - if frame.len != 8 { - return Err(H2ConnectionError::PingFrameInvalidLength { len: frame.len }); - } - - if flags.contains(PingFlags::Ack) { - // TODO: check that payload matches the one we sent? - return Ok(()); - } - - // send pong frame - let flags = PingFlags::Ack.into(); - let frame = Frame::new(FrameType::Ping(flags), StreamId::CONNECTION) - .with_len(payload.len() as u32); - self.write_frame(frame, payload).await?; - } - FrameType::GoAway => { - if frame.stream_id != StreamId::CONNECTION { - return Err(H2ConnectionError::GoAwayWithNonZeroStreamId { - stream_id: frame.stream_id, - }); - } - - self.goaway_recv = true; - - // TODO: this should probably have other effects than setting - // this flag. - } - FrameType::WindowUpdate => { - if payload.len() != 4 { - return Err(H2ConnectionError::WindowUpdateInvalidLength { - len: payload.len() as _, - }); - } - - let increment; - (_, (_, increment)) = parse_reserved_and_u31(payload) - .finish() - .map_err(|err| eyre::eyre!("parsing error: {err:?}"))?; - - if increment == 0 { - return Err(H2ConnectionError::WindowUpdateZeroIncrement); - } - - if frame.stream_id == StreamId::CONNECTION { - debug!("TODO: ignoring connection-wide window update"); - } else { - match self.state.streams.get_mut(&frame.stream_id) { - None => { - return Err(H2ConnectionError::WindowUpdateForUnknownStream { - stream_id: frame.stream_id, - }); - } - Some(_ss) => { - debug!("TODO: handle window update for stream {}", frame.stream_id) - } - } - } - } - FrameType::Continuation(_flags) => { - return Err(H2ConnectionError::UnexpectedContinuationFrame { - stream_id: frame.stream_id, - }); - } - FrameType::Unknown(ft) => { - trace!( - "ignoring unknown frame with type 0x{:x}, flags 0x{:x}", - ft.ty, - ft.flags - ); - } - } - - Ok(()) - } - - /// Send a RST_STREAM frame to the peer. - async fn rst( - &mut self, - stream_id: StreamId, - e: H2StreamError, - ) -> Result<(), H2ConnectionError> { - self.state.streams.remove(&stream_id); - - let error_code = e.as_known_error_code(); - debug!("Sending rst because: {e} (known error code: {error_code:?})"); - - debug!(%stream_id, ?error_code, "Sending RstStream"); - let payload = self.out_scratch.put_to_roll(4, |mut slice| { - slice.write_u32::(error_code.repr())?; - Ok(()) - })?; - - let frame = Frame::new(FrameType::RstStream, stream_id) - .with_len((payload.len()).try_into().unwrap()); - self.write_frame(frame, payload).await?; - - Ok(()) - } - - async fn read_headers( - &mut self, - headers_or_trailers: HeadersOrTrailers, - mode: ReadHeadersMode, - flags: BitFlags, - stream_id: StreamId, - payload: Roll, - rx: &mut mpsc::Receiver<(Frame, Roll)>, - ) -> Result<(), H2ConnectionError> { - let end_stream = flags.contains(HeadersFlags::EndStream); - - enum Data { - Single(Roll), - Multi(SmallVec<[Roll; 2]>), - } - - let data = if flags.contains(HeadersFlags::EndHeaders) { - // good, no continuation frames needed - Data::Single(payload) - } else { - // read continuation frames - - #[allow(unused, clippy::let_unit_value)] - let flags = (); // don't accidentally use the `flags` variable - - let mut fragments = smallvec![payload]; - - loop { - let (continuation_frame, continuation_payload) = match rx.recv().await { - Some(t) => t, - None => { - // even though this error is "for a stream", it's a - // connection error, because it means the peer doesn't - // know how to speak HTTP/2. - return Err(H2ConnectionError::ExpectedContinuationFrame { - stream_id, - frame_type: None, - }); - } - }; - - if stream_id != continuation_frame.stream_id { - return Err(H2ConnectionError::ExpectedContinuationForStream { - stream_id, - continuation_stream_id: continuation_frame.stream_id, - }); - } - - let cont_flags = match continuation_frame.frame_type { - FrameType::Continuation(flags) => flags, - other => { - return Err(H2ConnectionError::ExpectedContinuationFrame { - stream_id, - frame_type: Some(other), - }) - } - }; - - // add fragment - fragments.push(continuation_payload); - - if cont_flags.contains(ContinuationFlags::EndHeaders) { - // we're done - break; - } - } - - Data::Multi(fragments) - }; - - if matches!(mode, ReadHeadersMode::Skip) { - // that's all we need to do: we're not actually validating the - // headers, we already send a RST - return Ok(()); - } - - let mut method: Option = None; - let mut scheme: Option = None; - let mut path: Option = None; - let mut authority: Option = None; - - let mut headers = Headers::default(); - - // TODO: find a way to propagate errors from here - probably will have to change - // the function signature in fluke-hpack, or just write to some captured - // error - let on_header_pair = |key: Cow<[u8]>, value: Cow<[u8]>| { - debug!( - "{headers_or_trailers:?} | {}: {}", - std::str::from_utf8(&key).unwrap_or(""), // TODO: does this hurt performance when debug logging is disabled? - std::str::from_utf8(&value).unwrap_or(""), - ); - - if &key[..1] == b":" { - if matches!(headers_or_trailers, HeadersOrTrailers::Trailers) { - // TODO: proper error handling - panic!("trailers cannot contain pseudo-headers"); - } - - // it's a pseudo-header! - // TODO: reject headers that occur after pseudo-headers - match &key[1..] { - b"method" => { - // TODO: error handling - let value: PieceStr = Piece::from(value.to_vec()).to_str().unwrap(); - if method.replace(Method::from(value)).is_some() { - unreachable!(); // No duplicate allowed. - } - } - b"scheme" => { - // TODO: error handling - let value: PieceStr = Piece::from(value.to_vec()).to_str().unwrap(); - if scheme.replace(value.parse().unwrap()).is_some() { - unreachable!(); // No duplicate allowed. - } - } - b"path" => { - // TODO: error handling - let value: PieceStr = Piece::from(value.to_vec()).to_str().unwrap(); - if value.len() == 0 || path.replace(value).is_some() { - unreachable!(); // No empty path nor duplicate allowed. - } - } - b"authority" => { - // TODO: error handling - let value: PieceStr = Piece::from(value.to_vec()).to_str().unwrap(); - if authority.replace(value.parse().unwrap()).is_some() { - unreachable!(); // No duplicate allowed. (h2spec doesn't seem to test for - // this case but rejecting duplicates seems reasonable.) - } - } - _ => { - debug!("ignoring pseudo-header"); - } - } - } else { - // TODO: what do we do in case of malformed header names? - // ignore it? return a 400? - let name = HeaderName::from_bytes(&key[..]).expect("malformed header name"); - let value: Piece = value.to_vec().into(); - headers.append(name, value); - } - }; - - match data { - Data::Single(payload) => { - self.hpack_dec - .decode_with_cb(&payload[..], on_header_pair) - .map_err(|e| H2ConnectionError::CompressionError(format!("{e:?}")))?; - } - Data::Multi(fragments) => { - let total_len = fragments.iter().map(|f| f.len()).sum(); - // this is a slow path, let's do a little heap allocation. we could - // be using `RollMut` for this, but it would probably need to resize - // a bunch - let mut payload = Vec::with_capacity(total_len); - for frag in &fragments { - payload.extend_from_slice(&frag[..]); - } - self.hpack_dec - .decode_with_cb(&payload[..], on_header_pair) - .map_err(|e| H2ConnectionError::CompressionError(format!("{e:?}")))?; - } - }; - - match headers_or_trailers { - HeadersOrTrailers::Headers => { - // TODO: cf. https://httpwg.org/specs/rfc9113.html#HttpRequest - // A server SHOULD treat a request as malformed if it contains a Host header - // field that identifies an entity that differs from the entity in the - // ":authority" pseudo-header field. - - // TODO: proper error handling (return 400) - let method = method.unwrap(); - let scheme = scheme.unwrap(); - - let path = path.unwrap(); - let path_and_query: PathAndQuery = path.parse().unwrap(); - - let authority = match authority { - Some(authority) => Some(authority), - None => headers - .get(header::HOST) - .map(|host| host.as_str().unwrap().parse().unwrap()), - }; - - let mut uri_parts: http::uri::Parts = Default::default(); - uri_parts.scheme = Some(scheme); - uri_parts.authority = authority; - uri_parts.path_and_query = Some(path_and_query); - - let uri = http::uri::Uri::from_parts(uri_parts).unwrap(); - - let req = Request { - method, - uri, - version: Version::HTTP_2, - headers, - }; - - let responder = Responder { - encoder: H2Encoder { - stream_id, - tx: self.ev_tx.clone(), - state: EncoderState::ExpectResponseHeaders, - }, - // TODO: why tf is this state encoded twice? is that really - // necessary? I know it's for typestates and H2Encoder needs - // to look up its state at runtime I guess, but.. that's not great? - state: ExpectResponseHeaders, - }; - - let (piece_tx, piece_rx) = mpsc::channel::(1); // TODO: is 1 a sensible value here? - - let req_body = H2Body { - // FIXME: that's not right. h2 requests can still specify - // a content-length - content_length: if end_stream { Some(0) } else { None }, - eof: end_stream, - rx: piece_rx, - }; - - fluke_maybe_uring::spawn({ - let driver = self.driver.clone(); - async move { - let mut req_body = req_body; - let responder = responder; - - match driver.handle(req, &mut req_body, responder).await { - Ok(_responder) => { - debug!("Handler completed successfully, gave us a responder"); - } - Err(e) => { - // TODO: actually handle that error. - debug!("Handler returned an error: {e}") - } - } - } - }); - - self.state.streams.insert( - stream_id, - if end_stream { - StreamState::HalfClosedRemote - } else { - StreamState::Open(piece_tx) - }, - ); - } - HeadersOrTrailers::Trailers => { - match self.state.streams.get_mut(&stream_id) { - Some(StreamState::Open(body_tx)) => { - if body_tx - .send(Ok(PieceOrTrailers::Trailers(Box::new(headers)))) - .await - .is_err() - { - // the body is being ignored, but there's no point in - // resetting the stream since we just got the end of it - } - } - _ => { - unreachable!("stream state should be open when we receive trailers") - } - } - self.state.streams.remove(&stream_id); - } - } - - Ok(()) - } -} - -enum ReadHeadersMode { - // we're accepting the stream or processing trailers, we want to - // process the headers we read. - Process, - // we're refusing the stream, we want to skip over the headers we read. - Skip, -} diff --git a/crates/fluke/src/h2/server.rs b/crates/fluke/src/h2/server.rs index b7793bdf..16929fa0 100644 --- a/crates/fluke/src/h2/server.rs +++ b/crates/fluke/src/h2/server.rs @@ -1,18 +1,41 @@ -use std::rc::Rc; +use std::{ + borrow::Cow, + io::Write, + rc::Rc, + sync::atomic::{AtomicU32, Ordering}, +}; -use tracing::debug; +use byteorder::{BigEndian, WriteBytesExt}; +use enumflags2::BitFlags; +use eyre::Context; +use fluke_buffet::{Piece, PieceList, PieceStr, Roll, RollMut}; +use fluke_maybe_uring::io::{ReadOwned, WriteOwned}; +use http::{ + header, + uri::{Authority, PathAndQuery, Scheme}, + HeaderName, Version, +}; +use nom::Finish; +use smallvec::{smallvec, SmallVec}; +use tokio::sync::mpsc; +use tracing::{debug, trace, warn}; use crate::{ h2::{ - parse::{self, Frame, FrameType, StreamId}, - read::H2ReadContext, - types::ConnState, + body::{H2Body, H2BodyItem, PieceOrTrailers}, + encode::{EncoderState, H2Encoder}, + parse::{ + self, parse_reserved_and_u31, ContinuationFlags, DataFlags, Frame, FrameType, + HeadersFlags, PingFlags, PrioritySpec, Settings, SettingsFlags, StreamId, + }, + types::{ + ConnState, H2ConnectionError, H2Event, H2EventPayload, H2StreamError, + HeadersOrTrailers, StreamState, + }, }, util::read_and_parse, - ServerDriver, + ExpectResponseHeaders, Headers, Method, Request, Responder, ServerDriver, }; -use fluke_buffet::RollMut; -use fluke_maybe_uring::io::{ReadOwned, WriteOwned}; /// HTTP/2 server configuration pub struct ServerConf { @@ -68,10 +91,963 @@ pub async fn serve( debug!("sent settings frame"); } - H2ReadContext::new(driver.clone(), state, transport_w, out_scratch) + ServerContext::new(driver.clone(), state, transport_w, out_scratch) .work(client_buf, transport_r) .await?; debug!("finished serving"); Ok(()) } + +/// Reads and processes h2 frames from the client. +pub(crate) struct ServerContext { + driver: Rc, + state: ConnState, + hpack_dec: fluke_hpack::Decoder<'static>, + hpack_enc: fluke_hpack::Encoder<'static>, + out_scratch: RollMut, + + /// Whether we've received a GOAWAY frame. + pub goaway_recv: bool, + + /// TODO: encapsulate into a framer, don't + /// allow direct access from context methods + transport_w: W, + + ev_tx: mpsc::Sender, + ev_rx: mpsc::Receiver, +} + +impl ServerContext { + pub(crate) fn new( + driver: Rc, + state: ConnState, + transport_w: W, + out_scratch: RollMut, + ) -> Self { + let mut hpack_dec = fluke_hpack::Decoder::new(); + hpack_dec + .set_max_allowed_table_size(Settings::default().header_table_size.try_into().unwrap()); + + let hpack_enc = fluke_hpack::Encoder::new(); + + let (ev_tx, ev_rx) = tokio::sync::mpsc::channel::(32); + + Self { + driver, + ev_tx, + ev_rx, + state, + hpack_dec, + hpack_enc, + out_scratch, + goaway_recv: false, + transport_w, + } + } + + /// Reads and process h2 frames from the client. + pub(crate) async fn work( + mut self, + client_buf: RollMut, + transport_r: impl ReadOwned, + ) -> eyre::Result<()> { + let mut goaway_err: Option = None; + + { + // read frames and send them into an mpsc buffer of size 1 + let (tx, rx) = mpsc::channel::<(Frame, Roll)>(1); + + // store max frame size setting as an atomic so we can share it across tasks + // FIXME: the process_task should update this + let max_frame_size = Rc::new(AtomicU32::new(self.state.self_settings.max_frame_size)); + + let mut deframe_task = std::pin::pin!(Self::deframe_loop( + client_buf, + transport_r, + tx, + max_frame_size + )); + let mut process_task = std::pin::pin!(self.process_loop(rx)); + + tokio::select! { + res = &mut deframe_task => { + debug!(?res, "h2 deframe task finished"); + + if let Err(H2ConnectionError::ReadError(e)) = res { + let mut should_ignore_err = false; + + // if this is a connection reset and we've sent a goaway, ignore it + if let Some(io_error) = e.root_cause().downcast_ref::() { + if io_error.kind() == std::io::ErrorKind::ConnectionReset { + should_ignore_err = true; + } + } + + if !should_ignore_err { + return Err(e.wrap_err("h2 io")); + } + } + + if let Err(e) = (&mut process_task).await { + debug!("h2 process task finished with error: {e}"); + return Err(e).wrap_err("h2 process"); + } + } + res = &mut process_task => { + debug!(?res, "h2 process task finished"); + + if let Err(err) = res { + goaway_err = Some(err); + } + } + } + } + + if let Some(err) = goaway_err { + let error_code = err.as_known_error_code(); + debug!("Connection error: {err} ({err:?}) (code {error_code:?})"); + + // TODO: don't heap-allocate here + let additional_debug_data = format!("{err}").into_bytes(); + + // TODO: figure out graceful shutdown: this would involve sending a goaway + // before this point, and processing all the connections we've accepted + debug!(last_stream_id = %self.state.last_stream_id, ?error_code, "Sending GoAway"); + let payload = + self.out_scratch + .put_to_roll(8 + additional_debug_data.len(), |mut slice| { + slice.write_u32::(self.state.last_stream_id.0)?; + slice.write_u32::(error_code.repr())?; + slice.write_all(additional_debug_data.as_slice())?; + + Ok(()) + })?; + + let frame = Frame::new(FrameType::GoAway, StreamId::CONNECTION); + self.write_frame(frame, payload).await?; + } + + Ok(()) + } + + async fn deframe_loop( + mut client_buf: RollMut, + mut transport_r: impl ReadOwned, + tx: mpsc::Sender<(Frame, Roll)>, + max_frame_size: Rc, + ) -> Result<(), H2ConnectionError> { + 'read_frames: loop { + const MAX_FRAME_HEADER_SIZE: usize = 128; + let frame; + let frame_res = read_and_parse( + Frame::parse, + &mut transport_r, + client_buf, + MAX_FRAME_HEADER_SIZE, + ) + .await; + + let maybe_frame = match frame_res { + Ok(inner) => inner, + Err(e) => return Err(H2ConnectionError::ReadError(e)), + }; + (client_buf, frame) = match maybe_frame { + Some((client_buf, frame)) => (client_buf, frame), + None => { + debug!("Peer went away before sending a frame"); + break 'read_frames; + } + }; + + debug!(?frame, "Received"); + + let max_frame_size = max_frame_size.load(Ordering::Relaxed); + if frame.len > max_frame_size { + return Err(H2ConnectionError::FrameTooLarge { + frame_type: frame.frame_type, + frame_size: frame.len, + max_frame_size, + }); + } + + let mut payload; + (client_buf, payload) = match read_and_parse( + nom::bytes::streaming::take(frame.len as usize), + &mut transport_r, + client_buf, + frame.len as usize, + ) + .await? + { + Some((client_buf, payload)) => (client_buf, payload), + None => { + return Err(H2ConnectionError::IncompleteFrame { + frame_type: frame.frame_type, + frame_size: frame.len, + }) + } + }; + + let has_padding = match frame.frame_type { + FrameType::Data(flags) => flags.contains(DataFlags::Padded), + FrameType::Headers(flags) => flags.contains(HeadersFlags::Padded), + _ => false, + }; + + if has_padding { + if payload.is_empty() { + return Err(H2ConnectionError::PaddedFrameEmpty { + frame_type: frame.frame_type, + }); + } + + let padding_length_roll; + (padding_length_roll, payload) = payload.split_at(1); + let padding_length = padding_length_roll[0] as usize; + if payload.len() < padding_length { + return Err(H2ConnectionError::PaddedFrameTooShort { + frame_type: frame.frame_type, + padding_length, + frame_size: frame.len, + }); + } + + // padding is on the end of the payload + let at = payload.len() - padding_length; + (payload, _) = payload.split_at(at); + } + + if tx.send((frame, payload)).await.is_err() { + debug!("h2 deframer: receiver dropped, closing connection"); + return Ok(()); + } + } + + Ok(()) + } + + async fn process_loop( + &mut self, + mut rx: mpsc::Receiver<(Frame, Roll)>, + ) -> Result<(), H2ConnectionError> { + loop { + tokio::select! { + ev = self.ev_rx.recv() => { + match ev { + Some(ev) => self.handle_event(ev).await?, + None => unreachable!("the context owns a copy of the sender, and this method has &mut self, so the sender can't be dropped while this method is running"), + } + }, + maybe_frame = rx.recv() => { + if let Some((frame, payload)) = maybe_frame { + self.process_frame(frame, payload, &mut rx).await?; + } else { + debug!("h2 process task: peer hung up"); + break; + } + } + } + } + + Ok(()) + } + + async fn handle_event(&mut self, ev: H2Event) -> Result<(), H2ConnectionError> { + match ev.payload { + H2EventPayload::Headers(res) => { + let flags = HeadersFlags::EndHeaders; + let frame = Frame::new(FrameType::Headers(flags.into()), ev.stream_id); + + // TODO: don't allocate so much for headers. all `encode_into` + // wants is an `IntoIter`, we can definitely have a custom iterator + // that operates on all this instead of using a `Vec`. + + // TODO: limit header size + let mut headers: Vec<(&[u8], &[u8])> = vec![]; + headers.push((b":status", res.status.as_str().as_bytes())); + for (name, value) in res.headers.iter() { + if name == http::header::TRANSFER_ENCODING { + // do not set transfer-encoding: chunked when doing HTTP/2 + continue; + } + headers.push((name.as_str().as_bytes(), value)); + } + + assert_eq!(self.out_scratch.len(), 0); + self.hpack_enc + .encode_into(headers, &mut self.out_scratch) + .map_err(H2ConnectionError::WriteError)?; + let payload = self.out_scratch.take_all(); + + self.write_frame(frame, payload).await?; + } + H2EventPayload::BodyChunk(chunk) => { + let flags = BitFlags::::default(); + let frame = Frame::new(FrameType::Data(flags), ev.stream_id); + + self.write_frame(frame, chunk).await?; + } + H2EventPayload::BodyEnd => { + // FIXME: this should transition the stream to `Closed` + // state (or at the very least `HalfClosedLocal`). + // Either way, whoever owns the stream state should know + // about it, cf. https://github.com/hapsoc/fluke/issues/123 + + let flags = DataFlags::EndStream; + let frame = Frame::new(FrameType::Data(flags.into()), ev.stream_id); + self.write_frame(frame, Roll::empty()).await?; + } + } + + Ok(()) + } + + async fn write_frame( + &mut self, + frame: Frame, + payload: impl Into, + ) -> Result<(), H2ConnectionError> { + let payload = payload.into(); + + match &frame.frame_type { + FrameType::Data(headers) => { + if headers.contains(DataFlags::EndStream) { + // if the stream is open, this transitions to HalfClosedLocal. + if let Some(ss) = self.state.streams.get_mut(&frame.stream_id) { + match ss { + StreamState::Open(_) => { + // transition through StreamState::HalfClosedRemote + // so we don't have to remove/re-insert. + let mut entry = StreamState::HalfClosedRemote; + std::mem::swap(&mut entry, ss); + + let body_tx = match entry { + StreamState::Open(body_tx) => body_tx, + _ => unreachable!(), + }; + + *ss = StreamState::HalfClosedLocal(body_tx); + } + _ => { + // transition to closed + self.state.streams.remove(&frame.stream_id); + } + } + } + } + } + _ => { + // muffin. + } + } + + let frame_roll = frame.into_roll(&mut self.out_scratch)?; + + if payload.is_empty() { + self.transport_w + .write_all(frame_roll) + .await + .map_err(H2ConnectionError::WriteError)?; + } else { + self.transport_w + .writev_all(PieceList::default().with(frame_roll).with(payload)) + .await + .map_err(H2ConnectionError::WriteError)?; + } + + Ok(()) + } + + async fn process_frame( + &mut self, + frame: Frame, + mut payload: Roll, + rx: &mut mpsc::Receiver<(Frame, Roll)>, + ) -> Result<(), H2ConnectionError> { + match frame.frame_type { + FrameType::Data(flags) => { + let ss = self.state.streams.get_mut(&frame.stream_id).ok_or( + H2ConnectionError::StreamClosed { + stream_id: frame.stream_id, + }, + )?; + + match ss { + StreamState::Open(body_tx) | StreamState::HalfClosedLocal(body_tx) => { + if body_tx + .send(Ok(PieceOrTrailers::Piece(payload.into()))) + .await + .is_err() + { + warn!("TODO: The body is being ignored, we should reset the stream"); + } + + if flags.contains(DataFlags::EndStream) { + // if we're HalfClosedLocal, this transitions to Closed + // otherwise, it transitions to HalfClosedRemote + if matches!(ss, StreamState::Open(_)) { + *ss = StreamState::HalfClosedRemote; + } else { + self.state.streams.remove(&frame.stream_id); + } + } + } + StreamState::HalfClosedRemote => { + debug!( + stream_id = %frame.stream_id, + "Received data for closed stream" + ); + self.rst(frame.stream_id, H2StreamError::StreamClosed) + .await?; + } + } + } + FrameType::Headers(flags) => { + if flags.contains(HeadersFlags::Priority) { + let pri_spec; + (payload, pri_spec) = PrioritySpec::parse(payload) + .finish() + .map_err(|err| eyre::eyre!("parsing error: {err:?}"))?; + debug!(exclusive = %pri_spec.exclusive, stream_dependency = ?pri_spec.stream_dependency, weight = %pri_spec.weight, "received priority, exclusive"); + + if pri_spec.stream_dependency == frame.stream_id { + return Err(H2ConnectionError::HeadersInvalidPriority { + stream_id: frame.stream_id, + }); + } + } + + let headers_or_trailers; + let mode; + + match self.state.streams.get_mut(&frame.stream_id) { + None => { + headers_or_trailers = HeadersOrTrailers::Headers; + debug!( + stream_id = %frame.stream_id, + last_stream_id = %self.state.last_stream_id, + next_stream_count = %self.state.streams.len() + 1, + "Receiving headers", + ); + + if frame.stream_id.is_server_initiated() { + return Err(H2ConnectionError::ClientSidShouldBeOdd); + } + + if frame.stream_id <= self.state.last_stream_id { + debug!( + frame_stream_id = %frame.stream_id, + last_stream_id = %self.state.last_stream_id, + "Received headers for invalid stream ID" + ); + + // this stream may have existed, but it no longer does: + return Err(H2ConnectionError::StreamClosed { + stream_id: frame.stream_id, + }); + } else { + // TODO: if we're shutting down, ignore streams higher + // than the last one we accepted. + + let max_concurrent_streams = + self.state.self_settings.max_concurrent_streams; + let num_streams_if_accept = self.state.streams.len() + 1; + if num_streams_if_accept > max_concurrent_streams as _ { + // reset the stream, indicating we refused it + self.rst(frame.stream_id, H2StreamError::RefusedStream) + .await?; + + // but we still need to skip over any continuation frames + mode = ReadHeadersMode::Skip; + } else { + self.state.last_stream_id = frame.stream_id; + mode = ReadHeadersMode::Process; + } + } + } + Some(StreamState::Open(_) | StreamState::HalfClosedLocal(_)) => { + headers_or_trailers = HeadersOrTrailers::Trailers; + debug!("Receiving trailers for stream {}", frame.stream_id); + + if flags.contains(HeadersFlags::EndStream) { + // good, that's what we expect + mode = ReadHeadersMode::Process; + } else { + // ignore trailers, we're not accepting the stream + mode = ReadHeadersMode::Skip; + + self.rst(frame.stream_id, H2StreamError::TrailersNotEndStream) + .await?; + } + } + Some(StreamState::HalfClosedRemote) => { + return Err(H2ConnectionError::StreamClosed { + stream_id: frame.stream_id, + }); + } + } + + self.read_headers( + headers_or_trailers, + mode, + flags, + frame.stream_id, + payload, + rx, + ) + .await?; + } + FrameType::Priority => { + let pri_spec = match PrioritySpec::parse(payload) { + Ok((_rest, pri_spec)) => pri_spec, + Err(e) => { + todo!("handle connection error: invalid priority frame {e}") + } + }; + debug!(?pri_spec, "received priority frame"); + + if pri_spec.stream_dependency == frame.stream_id { + return Err(H2ConnectionError::HeadersInvalidPriority { + stream_id: frame.stream_id, + }); + } + } + // note: this always unconditionally transitions the stream to closed + FrameType::RstStream => match self.state.streams.remove(&frame.stream_id) { + None => { + return Err(H2ConnectionError::RstStreamForUnknownStream { + stream_id: frame.stream_id, + }) + } + Some(ss) => match ss { + StreamState::Open(body_tx) | StreamState::HalfClosedLocal(body_tx) => { + _ = body_tx + .send(Err(H2StreamError::ReceivedRstStream.into())) + .await; + } + StreamState::HalfClosedRemote => { + // good + } + }, + }, + FrameType::Settings(s) => { + if frame.stream_id != StreamId::CONNECTION { + return Err(H2ConnectionError::SettingsWithNonZeroStreamId { + stream_id: frame.stream_id, + }); + } + + if s.contains(SettingsFlags::Ack) { + debug!("Peer has acknowledged our settings, cool"); + if !payload.is_empty() { + return Err(H2ConnectionError::SettingsAckWithPayload { + len: payload.len() as _, + }); + } + } else { + let (_, settings) = Settings::parse(payload) + .finish() + .map_err(|err| eyre::eyre!("parsing error: {err:?}"))?; + self.hpack_enc + .set_max_table_size(settings.header_table_size as usize); + + debug!(?settings, "Received settings"); + self.state.peer_settings = settings; + + let frame = Frame::new( + FrameType::Settings(SettingsFlags::Ack.into()), + StreamId::CONNECTION, + ); + self.write_frame(frame, Roll::empty()).await?; + } + } + FrameType::PushPromise => { + return Err(H2ConnectionError::ClientSentPushPromise); + } + FrameType::Ping(flags) => { + if frame.stream_id != StreamId::CONNECTION { + return Err(H2ConnectionError::PingFrameWithNonZeroStreamId { + stream_id: frame.stream_id, + }); + } + + if frame.len != 8 { + return Err(H2ConnectionError::PingFrameInvalidLength { len: frame.len }); + } + + if flags.contains(PingFlags::Ack) { + // TODO: check that payload matches the one we sent? + return Ok(()); + } + + // send pong frame + let flags = PingFlags::Ack.into(); + let frame = Frame::new(FrameType::Ping(flags), StreamId::CONNECTION) + .with_len(payload.len() as u32); + self.write_frame(frame, payload).await?; + } + FrameType::GoAway => { + if frame.stream_id != StreamId::CONNECTION { + return Err(H2ConnectionError::GoAwayWithNonZeroStreamId { + stream_id: frame.stream_id, + }); + } + + self.goaway_recv = true; + + // TODO: this should probably have other effects than setting + // this flag. + } + FrameType::WindowUpdate => { + if payload.len() != 4 { + return Err(H2ConnectionError::WindowUpdateInvalidLength { + len: payload.len() as _, + }); + } + + let increment; + (_, (_, increment)) = parse_reserved_and_u31(payload) + .finish() + .map_err(|err| eyre::eyre!("parsing error: {err:?}"))?; + + if increment == 0 { + return Err(H2ConnectionError::WindowUpdateZeroIncrement); + } + + if frame.stream_id == StreamId::CONNECTION { + debug!("TODO: ignoring connection-wide window update"); + } else { + match self.state.streams.get_mut(&frame.stream_id) { + None => { + return Err(H2ConnectionError::WindowUpdateForUnknownStream { + stream_id: frame.stream_id, + }); + } + Some(_ss) => { + debug!("TODO: handle window update for stream {}", frame.stream_id) + } + } + } + } + FrameType::Continuation(_flags) => { + return Err(H2ConnectionError::UnexpectedContinuationFrame { + stream_id: frame.stream_id, + }); + } + FrameType::Unknown(ft) => { + trace!( + "ignoring unknown frame with type 0x{:x}, flags 0x{:x}", + ft.ty, + ft.flags + ); + } + } + + Ok(()) + } + + /// Send a RST_STREAM frame to the peer. + async fn rst( + &mut self, + stream_id: StreamId, + e: H2StreamError, + ) -> Result<(), H2ConnectionError> { + self.state.streams.remove(&stream_id); + + let error_code = e.as_known_error_code(); + debug!("Sending rst because: {e} (known error code: {error_code:?})"); + + debug!(%stream_id, ?error_code, "Sending RstStream"); + let payload = self.out_scratch.put_to_roll(4, |mut slice| { + slice.write_u32::(error_code.repr())?; + Ok(()) + })?; + + let frame = Frame::new(FrameType::RstStream, stream_id) + .with_len((payload.len()).try_into().unwrap()); + self.write_frame(frame, payload).await?; + + Ok(()) + } + + async fn read_headers( + &mut self, + headers_or_trailers: HeadersOrTrailers, + mode: ReadHeadersMode, + flags: BitFlags, + stream_id: StreamId, + payload: Roll, + rx: &mut mpsc::Receiver<(Frame, Roll)>, + ) -> Result<(), H2ConnectionError> { + let end_stream = flags.contains(HeadersFlags::EndStream); + + enum Data { + Single(Roll), + Multi(SmallVec<[Roll; 2]>), + } + + let data = if flags.contains(HeadersFlags::EndHeaders) { + // good, no continuation frames needed + Data::Single(payload) + } else { + // read continuation frames + + #[allow(unused, clippy::let_unit_value)] + let flags = (); // don't accidentally use the `flags` variable + + let mut fragments = smallvec![payload]; + + loop { + let (continuation_frame, continuation_payload) = match rx.recv().await { + Some(t) => t, + None => { + // even though this error is "for a stream", it's a + // connection error, because it means the peer doesn't + // know how to speak HTTP/2. + return Err(H2ConnectionError::ExpectedContinuationFrame { + stream_id, + frame_type: None, + }); + } + }; + + if stream_id != continuation_frame.stream_id { + return Err(H2ConnectionError::ExpectedContinuationForStream { + stream_id, + continuation_stream_id: continuation_frame.stream_id, + }); + } + + let cont_flags = match continuation_frame.frame_type { + FrameType::Continuation(flags) => flags, + other => { + return Err(H2ConnectionError::ExpectedContinuationFrame { + stream_id, + frame_type: Some(other), + }) + } + }; + + // add fragment + fragments.push(continuation_payload); + + if cont_flags.contains(ContinuationFlags::EndHeaders) { + // we're done + break; + } + } + + Data::Multi(fragments) + }; + + if matches!(mode, ReadHeadersMode::Skip) { + // that's all we need to do: we're not actually validating the + // headers, we already send a RST + return Ok(()); + } + + let mut method: Option = None; + let mut scheme: Option = None; + let mut path: Option = None; + let mut authority: Option = None; + + let mut headers = Headers::default(); + + // TODO: find a way to propagate errors from here - probably will have to change + // the function signature in fluke-hpack, or just write to some captured + // error + let on_header_pair = |key: Cow<[u8]>, value: Cow<[u8]>| { + debug!( + "{headers_or_trailers:?} | {}: {}", + std::str::from_utf8(&key).unwrap_or(""), // TODO: does this hurt performance when debug logging is disabled? + std::str::from_utf8(&value).unwrap_or(""), + ); + + if &key[..1] == b":" { + if matches!(headers_or_trailers, HeadersOrTrailers::Trailers) { + // TODO: proper error handling + panic!("trailers cannot contain pseudo-headers"); + } + + // it's a pseudo-header! + // TODO: reject headers that occur after pseudo-headers + match &key[1..] { + b"method" => { + // TODO: error handling + let value: PieceStr = Piece::from(value.to_vec()).to_str().unwrap(); + if method.replace(Method::from(value)).is_some() { + unreachable!(); // No duplicate allowed. + } + } + b"scheme" => { + // TODO: error handling + let value: PieceStr = Piece::from(value.to_vec()).to_str().unwrap(); + if scheme.replace(value.parse().unwrap()).is_some() { + unreachable!(); // No duplicate allowed. + } + } + b"path" => { + // TODO: error handling + let value: PieceStr = Piece::from(value.to_vec()).to_str().unwrap(); + if value.len() == 0 || path.replace(value).is_some() { + unreachable!(); // No empty path nor duplicate allowed. + } + } + b"authority" => { + // TODO: error handling + let value: PieceStr = Piece::from(value.to_vec()).to_str().unwrap(); + if authority.replace(value.parse().unwrap()).is_some() { + unreachable!(); // No duplicate allowed. (h2spec doesn't seem to test for + // this case but rejecting duplicates seems reasonable.) + } + } + _ => { + debug!("ignoring pseudo-header"); + } + } + } else { + // TODO: what do we do in case of malformed header names? + // ignore it? return a 400? + let name = HeaderName::from_bytes(&key[..]).expect("malformed header name"); + let value: Piece = value.to_vec().into(); + headers.append(name, value); + } + }; + + match data { + Data::Single(payload) => { + self.hpack_dec + .decode_with_cb(&payload[..], on_header_pair) + .map_err(|e| H2ConnectionError::CompressionError(format!("{e:?}")))?; + } + Data::Multi(fragments) => { + let total_len = fragments.iter().map(|f| f.len()).sum(); + // this is a slow path, let's do a little heap allocation. we could + // be using `RollMut` for this, but it would probably need to resize + // a bunch + let mut payload = Vec::with_capacity(total_len); + for frag in &fragments { + payload.extend_from_slice(&frag[..]); + } + self.hpack_dec + .decode_with_cb(&payload[..], on_header_pair) + .map_err(|e| H2ConnectionError::CompressionError(format!("{e:?}")))?; + } + }; + + match headers_or_trailers { + HeadersOrTrailers::Headers => { + // TODO: cf. https://httpwg.org/specs/rfc9113.html#HttpRequest + // A server SHOULD treat a request as malformed if it contains a Host header + // field that identifies an entity that differs from the entity in the + // ":authority" pseudo-header field. + + // TODO: proper error handling (return 400) + let method = method.unwrap(); + let scheme = scheme.unwrap(); + + let path = path.unwrap(); + let path_and_query: PathAndQuery = path.parse().unwrap(); + + let authority = match authority { + Some(authority) => Some(authority), + None => headers + .get(header::HOST) + .map(|host| host.as_str().unwrap().parse().unwrap()), + }; + + let mut uri_parts: http::uri::Parts = Default::default(); + uri_parts.scheme = Some(scheme); + uri_parts.authority = authority; + uri_parts.path_and_query = Some(path_and_query); + + let uri = http::uri::Uri::from_parts(uri_parts).unwrap(); + + let req = Request { + method, + uri, + version: Version::HTTP_2, + headers, + }; + + let responder = Responder { + encoder: H2Encoder { + stream_id, + tx: self.ev_tx.clone(), + state: EncoderState::ExpectResponseHeaders, + }, + // TODO: why tf is this state encoded twice? is that really + // necessary? I know it's for typestates and H2Encoder needs + // to look up its state at runtime I guess, but.. that's not great? + state: ExpectResponseHeaders, + }; + + let (piece_tx, piece_rx) = mpsc::channel::(1); // TODO: is 1 a sensible value here? + + let req_body = H2Body { + // FIXME: that's not right. h2 requests can still specify + // a content-length + content_length: if end_stream { Some(0) } else { None }, + eof: end_stream, + rx: piece_rx, + }; + + fluke_maybe_uring::spawn({ + let driver = self.driver.clone(); + async move { + let mut req_body = req_body; + let responder = responder; + + match driver.handle(req, &mut req_body, responder).await { + Ok(_responder) => { + debug!("Handler completed successfully, gave us a responder"); + } + Err(e) => { + // TODO: actually handle that error. + debug!("Handler returned an error: {e}") + } + } + } + }); + + self.state.streams.insert( + stream_id, + if end_stream { + StreamState::HalfClosedRemote + } else { + StreamState::Open(piece_tx) + }, + ); + } + HeadersOrTrailers::Trailers => { + match self.state.streams.get_mut(&stream_id) { + Some(StreamState::Open(body_tx)) => { + if body_tx + .send(Ok(PieceOrTrailers::Trailers(Box::new(headers)))) + .await + .is_err() + { + // the body is being ignored, but there's no point in + // resetting the stream since we just got the end of it + } + } + _ => { + unreachable!("stream state should be open when we receive trailers") + } + } + self.state.streams.remove(&stream_id); + } + } + + Ok(()) + } +} + +enum ReadHeadersMode { + // we're accepting the stream or processing trailers, we want to + // process the headers we read. + Process, + // we're refusing the stream, we want to skip over the headers we read. + Skip, +} From 4a0bff2b9f43cacb7e100585acdbf84fefa0ea7a Mon Sep 17 00:00:00 2001 From: Amos Wenger Date: Tue, 12 Mar 2024 15:44:39 +0100 Subject: [PATCH 13/17] better debug implementation for Frame --- crates/fluke/src/h2/parse.rs | 53 ++++++++++++++++++++++++++++++++++- crates/fluke/src/h2/server.rs | 8 ++++-- crates/fluke/src/h2/types.rs | 11 -------- 3 files changed, 57 insertions(+), 15 deletions(-) diff --git a/crates/fluke/src/h2/parse.rs b/crates/fluke/src/h2/parse.rs index 832cb5d8..549896bf 100644 --- a/crates/fluke/src/h2/parse.rs +++ b/crates/fluke/src/h2/parse.rs @@ -210,7 +210,6 @@ impl fmt::Display for StreamId { } /// See https://httpwg.org/specs/rfc9113.html#FrameHeader -#[derive(Debug)] pub struct Frame { pub frame_type: FrameType, pub reserved: u8, @@ -218,6 +217,58 @@ pub struct Frame { pub len: u32, } +impl fmt::Debug for Frame { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + let mut s = match &self.frame_type { + FrameType::Data(flags) => { + let mut s = f.debug_struct("Frame:Data"); + s.field("flags", flags); + s + } + FrameType::Headers(flags) => { + let mut s = f.debug_struct("Frame:Headers"); + s.field("flags", flags); + s + } + FrameType::Priority => f.debug_struct("Frame:Priority"), + FrameType::RstStream => f.debug_struct("Frame:RstStream"), + FrameType::Settings(flags) => { + let mut s = f.debug_struct("Frame:Settings"); + s.field("flags", flags); + s + } + FrameType::PushPromise => f.debug_struct("Frame:PushPromise"), + FrameType::Ping(flags) => { + let mut s = f.debug_struct("Frame:Ping"); + s.field("flags", flags); + s + } + FrameType::GoAway => f.debug_struct("Frame:GoAway"), + FrameType::WindowUpdate => f.debug_struct("Frame:WindowUpdate"), + FrameType::Continuation(flags) => { + let mut s = f.debug_struct("Frame:Continuation"); + s.field("flags", flags); + s + } + FrameType::Unknown(eft) => { + let mut s = f.debug_struct("Frame:Unknown"); + s.field("encoded_frame_type", eft); + s + } + }; + if self.reserved != 0 { + s.field("reserved", &self.reserved); + } + if self.stream_id != StreamId::CONNECTION { + s.field("stream_id", &self.stream_id); + } + if self.len > 0 { + s.field("len", &self.len); + } + s.finish() + } +} + impl Frame { /// Create a new frame with the given type and stream ID. pub fn new(frame_type: FrameType, stream_id: StreamId) -> Self { diff --git a/crates/fluke/src/h2/server.rs b/crates/fluke/src/h2/server.rs index 16929fa0..5e5dd6b2 100644 --- a/crates/fluke/src/h2/server.rs +++ b/crates/fluke/src/h2/server.rs @@ -18,7 +18,7 @@ use http::{ use nom::Finish; use smallvec::{smallvec, SmallVec}; use tokio::sync::mpsc; -use tracing::{debug, trace, warn}; +use tracing::{debug, trace}; use crate::{ h2::{ @@ -260,7 +260,7 @@ impl ServerContext { } }; - debug!(?frame, "Received"); + debug!(?frame, "<"); let max_frame_size = max_frame_size.load(Ordering::Relaxed); if frame.len > max_frame_size { @@ -408,6 +408,8 @@ impl ServerContext { frame: Frame, payload: impl Into, ) -> Result<(), H2ConnectionError> { + debug!(?frame, ">"); + let payload = payload.into(); match &frame.frame_type { @@ -480,7 +482,7 @@ impl ServerContext { .await .is_err() { - warn!("TODO: The body is being ignored, we should reset the stream"); + debug!("TODO: The body is being ignored, we should reset the stream"); } if flags.contains(DataFlags::EndStream) { diff --git a/crates/fluke/src/h2/types.rs b/crates/fluke/src/h2/types.rs index d82d7450..f2ca01c2 100644 --- a/crates/fluke/src/h2/types.rs +++ b/crates/fluke/src/h2/types.rs @@ -261,14 +261,3 @@ impl fmt::Debug for H2EventPayload { #[derive(thiserror::Error, Debug)] #[error("the peer closed the connection unexpectedly")] pub(crate) struct ConnectionClosed; - -fn is_peer_gone(e: &eyre::Report) -> bool { - if let Some(io_error) = e.root_cause().downcast_ref::() { - matches!( - io_error.kind(), - std::io::ErrorKind::BrokenPipe | std::io::ErrorKind::ConnectionReset - ) - } else { - false - } -} From 1099668ae43a8b6e39113188246130b2fc72c36c Mon Sep 17 00:00:00 2001 From: Amos Wenger Date: Tue, 12 Mar 2024 16:02:28 +0100 Subject: [PATCH 14/17] Uhm --- crates/fluke/src/h2/parse.rs | 93 ++++++++++++++++++++++------------- crates/fluke/src/h2/server.rs | 90 +++++++++++++++------------------ 2 files changed, 100 insertions(+), 83 deletions(-) diff --git a/crates/fluke/src/h2/parse.rs b/crates/fluke/src/h2/parse.rs index 549896bf..018ecfe5 100644 --- a/crates/fluke/src/h2/parse.rs +++ b/crates/fluke/src/h2/parse.rs @@ -219,52 +219,77 @@ pub struct Frame { impl fmt::Debug for Frame { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - let mut s = match &self.frame_type { + if self.stream_id.0 == 0 { + write!(f, "Conn:")?; + } else { + write!(f, "#{}:", self.stream_id.0)?; + } + + let name = match &self.frame_type { + FrameType::Data(_) => "Data", + FrameType::Headers(_) => "Headers", + FrameType::Priority => "Priority", + FrameType::RstStream => "RstStream", + FrameType::Settings(_) => "Settings", + FrameType::PushPromise => "PushPromise", + FrameType::Ping(_) => "Ping", + FrameType::GoAway => "GoAway", + FrameType::WindowUpdate => "WindowUpdate", + FrameType::Continuation(_) => "Continuation", + FrameType::Unknown(EncodedFrameType { ty, flags }) => { + return write!(f, "UnknownFrame({:#x}, {:#x})", ty, flags) + } + }; + let mut s = f.debug_struct(name); + + if self.reserved != 0 { + s.field("reserved", &self.reserved); + } + if self.len > 0 { + s.field("len", &self.len); + } + + // now write flags with DisplayDebug + struct DisplayDebug<'a, D: fmt::Display>(&'a D); + impl<'a, D: fmt::Display> fmt::Debug for DisplayDebug<'a, D> { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + fmt::Display::fmt(self.0, f) + } + } + + // for all the variants with flags, add a flags field, of value + // &DisplayDebug(flags) + match &self.frame_type { FrameType::Data(flags) => { - let mut s = f.debug_struct("Frame:Data"); - s.field("flags", flags); - s + if !flags.is_empty() { + s.field("flags", &DisplayDebug(flags)); + } } FrameType::Headers(flags) => { - let mut s = f.debug_struct("Frame:Headers"); - s.field("flags", flags); - s + if !flags.is_empty() { + s.field("flags", &DisplayDebug(flags)); + } } - FrameType::Priority => f.debug_struct("Frame:Priority"), - FrameType::RstStream => f.debug_struct("Frame:RstStream"), FrameType::Settings(flags) => { - let mut s = f.debug_struct("Frame:Settings"); - s.field("flags", flags); - s + if !flags.is_empty() { + s.field("flags", &DisplayDebug(flags)); + } } - FrameType::PushPromise => f.debug_struct("Frame:PushPromise"), FrameType::Ping(flags) => { - let mut s = f.debug_struct("Frame:Ping"); - s.field("flags", flags); - s + if !flags.is_empty() { + s.field("flags", &DisplayDebug(flags)); + } } - FrameType::GoAway => f.debug_struct("Frame:GoAway"), - FrameType::WindowUpdate => f.debug_struct("Frame:WindowUpdate"), FrameType::Continuation(flags) => { - let mut s = f.debug_struct("Frame:Continuation"); - s.field("flags", flags); - s + if !flags.is_empty() { + s.field("flags", &DisplayDebug(flags)); + } } - FrameType::Unknown(eft) => { - let mut s = f.debug_struct("Frame:Unknown"); - s.field("encoded_frame_type", eft); - s + _ => { + // muffin } - }; - if self.reserved != 0 { - s.field("reserved", &self.reserved); - } - if self.stream_id != StreamId::CONNECTION { - s.field("stream_id", &self.stream_id); - } - if self.len > 0 { - s.field("len", &self.len); } + s.finish() } } diff --git a/crates/fluke/src/h2/server.rs b/crates/fluke/src/h2/server.rs index 5e5dd6b2..b88af0d5 100644 --- a/crates/fluke/src/h2/server.rs +++ b/crates/fluke/src/h2/server.rs @@ -49,49 +49,15 @@ impl Default for ServerConf { } pub async fn serve( - (mut transport_r, mut transport_w): (impl ReadOwned, impl WriteOwned), + (transport_r, transport_w): (impl ReadOwned, impl WriteOwned), conf: Rc, - mut client_buf: RollMut, + client_buf: RollMut, driver: Rc, ) -> eyre::Result<()> { let mut state = ConnState::default(); state.self_settings.max_concurrent_streams = conf.max_streams; - (client_buf, _) = match read_and_parse( - parse::preface, - &mut transport_r, - client_buf, - parse::PREFACE.len(), - ) - .await? - { - Some((client_buf, frame)) => (client_buf, frame), - None => { - debug!("h2 client closed connection before sending preface"); - return Ok(()); - } - }; - debug!("read preface"); - - let mut out_scratch = RollMut::alloc()?; - - // we have to send an initial settings frame - { - let payload_roll = state.self_settings.into_roll(&mut out_scratch)?; - let frame_roll = Frame::new( - FrameType::Settings(Default::default()), - StreamId::CONNECTION, - ) - .with_len(payload_roll.len().try_into().unwrap()) - .into_roll(&mut out_scratch)?; - - transport_w - .writev_all(vec![frame_roll, payload_roll]) - .await?; - debug!("sent settings frame"); - } - - ServerContext::new(driver.clone(), state, transport_w, out_scratch) + ServerContext::new(driver.clone(), state, transport_w)? .work(client_buf, transport_r) .await?; @@ -119,12 +85,7 @@ pub(crate) struct ServerContext { } impl ServerContext { - pub(crate) fn new( - driver: Rc, - state: ConnState, - transport_w: W, - out_scratch: RollMut, - ) -> Self { + pub(crate) fn new(driver: Rc, state: ConnState, transport_w: W) -> eyre::Result { let mut hpack_dec = fluke_hpack::Decoder::new(); hpack_dec .set_max_allowed_table_size(Settings::default().header_table_size.try_into().unwrap()); @@ -133,25 +94,53 @@ impl ServerContext { let (ev_tx, ev_rx) = tokio::sync::mpsc::channel::(32); - Self { + Ok(Self { driver, ev_tx, ev_rx, state, hpack_dec, hpack_enc, - out_scratch, + out_scratch: RollMut::alloc()?, goaway_recv: false, transport_w, - } + }) } /// Reads and process h2 frames from the client. pub(crate) async fn work( mut self, - client_buf: RollMut, - transport_r: impl ReadOwned, + mut client_buf: RollMut, + mut transport_r: impl ReadOwned, ) -> eyre::Result<()> { + // first read the preface + { + (client_buf, _) = match read_and_parse( + parse::preface, + &mut transport_r, + client_buf, + parse::PREFACE.len(), + ) + .await? + { + Some((client_buf, frame)) => (client_buf, frame), + None => { + debug!("h2 client closed connection before sending preface"); + return Ok(()); + } + }; + } + + // then send our initial settings + { + let payload = self.state.self_settings.into_roll(&mut self.out_scratch)?; + let frame = Frame::new( + FrameType::Settings(Default::default()), + StreamId::CONNECTION, + ); + self.write_frame(frame, payload).await?; + } + let mut goaway_err: Option = None; { @@ -439,6 +428,9 @@ impl ServerContext { } } } + FrameType::Settings(_) => { + // TODO: keep track of whether our new settings have been acknowledged + } _ => { // muffin. } @@ -654,7 +646,7 @@ impl ServerContext { self.hpack_enc .set_max_table_size(settings.header_table_size as usize); - debug!(?settings, "Received settings"); + debug!("Peer sent us {settings:#?}"); self.state.peer_settings = settings; let frame = Frame::new( From bdaf0fcba7e45deb94011cfd3d45179b17886ff8 Mon Sep 17 00:00:00 2001 From: Amos Wenger Date: Tue, 12 Mar 2024 16:28:09 +0100 Subject: [PATCH 15/17] mh --- crates/fluke-buffet/src/roll.rs | 2 +- crates/fluke/src/h2/server.rs | 13 +++++++++++-- crates/fluke/src/util.rs | 12 ++++-------- test-crates/fluke-h2spec/src/main.rs | 2 +- 4 files changed, 17 insertions(+), 12 deletions(-) diff --git a/crates/fluke-buffet/src/roll.rs b/crates/fluke-buffet/src/roll.rs index 55a4db03..7566c642 100644 --- a/crates/fluke-buffet/src/roll.rs +++ b/crates/fluke-buffet/src/roll.rs @@ -255,7 +255,7 @@ impl RollMut { init: 0, }; let (res, mut read_into) = r.read(read_into).await; - tracing::trace!(init = %read_into.init, "read_into done!"); + tracing::trace!("read_into got {} bytes", read_into.init); read_into.buf.len += read_into.init; (res, read_into.buf) } diff --git a/crates/fluke/src/h2/server.rs b/crates/fluke/src/h2/server.rs index b88af0d5..6f8d4e3c 100644 --- a/crates/fluke/src/h2/server.rs +++ b/crates/fluke/src/h2/server.rs @@ -115,6 +115,7 @@ impl ServerContext { ) -> eyre::Result<()> { // first read the preface { + debug!("Reading preface"); (client_buf, _) = match read_and_parse( parse::preface, &mut transport_r, @@ -129,10 +130,12 @@ impl ServerContext { return Ok(()); } }; + debug!("Reading preface: done"); } // then send our initial settings { + debug!("Sending initial settings"); let payload = self.state.self_settings.into_roll(&mut self.out_scratch)?; let frame = Frame::new( FrameType::Settings(Default::default()), @@ -159,6 +162,8 @@ impl ServerContext { )); let mut process_task = std::pin::pin!(self.process_loop(rx)); + debug!("Starting both deframe & process tasks"); + tokio::select! { res = &mut deframe_task => { debug!(?res, "h2 deframe task finished"); @@ -229,6 +234,7 @@ impl ServerContext { 'read_frames: loop { const MAX_FRAME_HEADER_SIZE: usize = 128; let frame; + debug!("Reading frame... Buffer length: {}", client_buf.len()); let frame_res = read_and_parse( Frame::parse, &mut transport_r, @@ -244,11 +250,14 @@ impl ServerContext { (client_buf, frame) = match maybe_frame { Some((client_buf, frame)) => (client_buf, frame), None => { - debug!("Peer went away before sending a frame"); + debug!("Peer hung up"); break 'read_frames; } }; - + debug!( + "Reading frame... done! New buffer length: {}", + client_buf.len() + ); debug!(?frame, "<"); let max_frame_size = max_frame_size.load(Ordering::Relaxed); diff --git a/crates/fluke/src/util.rs b/crates/fluke/src/util.rs index b3aef610..65d4ba2a 100644 --- a/crates/fluke/src/util.rs +++ b/crates/fluke/src/util.rs @@ -18,11 +18,7 @@ where Parser: Fn(Roll) -> IResult, { loop { - trace!( - "reading+parsing (buf.len={}, buf.cap={})", - buf.len(), - buf.cap() - ); + trace!("Running parser (len={}, cap={})", buf.len(), buf.cap()); let filled = buf.filled(); match parser(filled) { @@ -34,7 +30,7 @@ where if err.is_incomplete() { { trace!( - "incomplete request, need more data. start of buffer: {:?}", + "need more data. so far, we have:\n{:?}", &buf[..std::cmp::min(buf.len(), 128)].hex_dump() ); } @@ -50,9 +46,9 @@ where buf.reserve()?; } trace!( - "calling read_into, buf.cap={}, buf.len={} read_limit={read_limit}", + "Calling read_into (len={}, cap={}, read_limit={read_limit})", + buf.len(), buf.cap(), - buf.len() ); (res, buf) = buf.read_into(read_limit, stream).await; diff --git a/test-crates/fluke-h2spec/src/main.rs b/test-crates/fluke-h2spec/src/main.rs index 139c83eb..636f538a 100644 --- a/test-crates/fluke-h2spec/src/main.rs +++ b/test-crates/fluke-h2spec/src/main.rs @@ -20,7 +20,7 @@ fn main() { eprintln!("Couldn't parse RUST_LOG: {e}"); EnvFilter::try_new("info").unwrap() })) - .without_time() + // .without_time() .init(); let h2spec_binary = match which::which("h2spec") { From cff1f9923160decb6330444d54295641c8f94ed3 Mon Sep 17 00:00:00 2001 From: Amos Wenger Date: Tue, 12 Mar 2024 16:37:01 +0100 Subject: [PATCH 16/17] well every time I lose silly time to this I improve debug logging, so. --- crates/fluke/src/h2/parse.rs | 1 + crates/fluke/src/h2/server.rs | 24 ++++++++++++++++++++++-- 2 files changed, 23 insertions(+), 2 deletions(-) diff --git a/crates/fluke/src/h2/parse.rs b/crates/fluke/src/h2/parse.rs index 018ecfe5..a4717028 100644 --- a/crates/fluke/src/h2/parse.rs +++ b/crates/fluke/src/h2/parse.rs @@ -579,6 +579,7 @@ impl Settings { const MAX_FRAME_SIZE_ALLOWED_RANGE: RangeInclusive = (1 << 14)..=((1 << 24) - 1); pub fn parse(mut i: Roll) -> IResult { + tracing::trace!("parsing settings frame, roll length: {}", i.len()); let mut settings = Self::default(); while !i.is_empty() { diff --git a/crates/fluke/src/h2/server.rs b/crates/fluke/src/h2/server.rs index 6f8d4e3c..61f9d8bb 100644 --- a/crates/fluke/src/h2/server.rs +++ b/crates/fluke/src/h2/server.rs @@ -269,6 +269,11 @@ impl ServerContext { }); } + debug!( + "Reading payload of size {}... Buffer length: {}", + frame.len, + client_buf.len() + ); let mut payload; (client_buf, payload) = match read_and_parse( nom::bytes::streaming::take(frame.len as usize), @@ -286,6 +291,10 @@ impl ServerContext { }) } }; + debug!( + "Reading payload... done! New buffer length: {}", + client_buf.len() + ); let has_padding = match frame.frame_type { FrameType::Data(flags) => flags.contains(DataFlags::Padded), @@ -403,11 +412,10 @@ impl ServerContext { async fn write_frame( &mut self, - frame: Frame, + mut frame: Frame, payload: impl Into, ) -> Result<(), H2ConnectionError> { debug!(?frame, ">"); - let payload = payload.into(); match &frame.frame_type { @@ -445,14 +453,25 @@ impl ServerContext { } } + // TODO: enforce max_frame_size from the peer settings, not just u32::max + frame.len = payload + .len() + .try_into() + .map_err(|_| H2ConnectionError::FrameTooLarge { + frame_type: frame.frame_type, + frame_size: payload.len() as _, + max_frame_size: u32::MAX, + })?; let frame_roll = frame.into_roll(&mut self.out_scratch)?; if payload.is_empty() { + trace!("Writing frame without payload"); self.transport_w .write_all(frame_roll) .await .map_err(H2ConnectionError::WriteError)?; } else { + trace!("Writing frame with payload"); self.transport_w .writev_all(PieceList::default().with(frame_roll).with(payload)) .await @@ -663,6 +682,7 @@ impl ServerContext { StreamId::CONNECTION, ); self.write_frame(frame, Roll::empty()).await?; + debug!("Acknowledged peer settings"); } } FrameType::PushPromise => { From 0a0423467f1b6b2d3c29ccb6961848bc7a8a1a3d Mon Sep 17 00:00:00 2001 From: Amos Wenger Date: Tue, 12 Mar 2024 16:40:40 +0100 Subject: [PATCH 17/17] Fix more cases --- crates/fluke/src/h2/server.rs | 30 +++++++++++++++++------------- crates/fluke/src/h2/types.rs | 6 ++++++ 2 files changed, 23 insertions(+), 13 deletions(-) diff --git a/crates/fluke/src/h2/server.rs b/crates/fluke/src/h2/server.rs index 61f9d8bb..10ca4aa9 100644 --- a/crates/fluke/src/h2/server.rs +++ b/crates/fluke/src/h2/server.rs @@ -340,12 +340,8 @@ impl ServerContext { ) -> Result<(), H2ConnectionError> { loop { tokio::select! { - ev = self.ev_rx.recv() => { - match ev { - Some(ev) => self.handle_event(ev).await?, - None => unreachable!("the context owns a copy of the sender, and this method has &mut self, so the sender can't be dropped while this method is running"), - } - }, + biased; + maybe_frame = rx.recv() => { if let Some((frame, payload)) = maybe_frame { self.process_frame(frame, payload, &mut rx).await?; @@ -354,6 +350,13 @@ impl ServerContext { break; } } + + ev = self.ev_rx.recv() => { + match ev { + Some(ev) => self.handle_event(ev).await?, + None => unreachable!("the context owns a copy of the sender, and this method has &mut self, so the sender can't be dropped while this method is running"), + } + }, } } @@ -557,14 +560,15 @@ impl ServerContext { return Err(H2ConnectionError::ClientSidShouldBeOdd); } - if frame.stream_id <= self.state.last_stream_id { - debug!( - frame_stream_id = %frame.stream_id, - last_stream_id = %self.state.last_stream_id, - "Received headers for invalid stream ID" + if frame.stream_id < self.state.last_stream_id { + // we're going back? we can't. + return Err( + H2ConnectionError::ClientSidShouldBeNumericallyIncreasing { + stream_id: frame.stream_id, + last_stream_id: self.state.last_stream_id, + }, ); - - // this stream may have existed, but it no longer does: + } else if frame.stream_id == self.state.last_stream_id { return Err(H2ConnectionError::StreamClosed { stream_id: frame.stream_id, }); diff --git a/crates/fluke/src/h2/types.rs b/crates/fluke/src/h2/types.rs index f2ca01c2..be58b18d 100644 --- a/crates/fluke/src/h2/types.rs +++ b/crates/fluke/src/h2/types.rs @@ -101,6 +101,12 @@ pub(crate) enum H2ConnectionError { #[error("client tried to initiate an even-numbered stream")] ClientSidShouldBeOdd, + #[error("client stream IDs should be numerically increasing")] + ClientSidShouldBeNumericallyIncreasing { + stream_id: StreamId, + last_stream_id: StreamId, + }, + #[error("received {frame_type:?} frame with Padded flag but empty payload")] PaddedFrameEmpty { frame_type: FrameType },