diff --git a/lib/src/protocol/mux/converter.rs b/lib/src/protocol/mux/converter.rs index 35b8b2c7a..6a6579806 100644 --- a/lib/src/protocol/mux/converter.rs +++ b/lib/src/protocol/mux/converter.rs @@ -1,13 +1,17 @@ use std::str::from_utf8_unchecked; -use kawa::{AsBuffer, Block, BlockConverter, Chunk, Flags, Kawa, Pair, StatusLine, Store}; - -use crate::protocol::http::parser::compare_no_case; +use kawa::{ + AsBuffer, Block, BlockConverter, Chunk, Flags, Kawa, Pair, ParsingErrorKind, ParsingPhase, + StatusLine, Store, +}; -use super::{ - parser::{FrameHeader, FrameType, H2Error}, - serializer::{gen_frame_header, gen_rst_stream}, - StreamId, +use crate::protocol::{ + http::parser::compare_no_case, + mux::{ + parser::{str_to_error_code, FrameHeader, FrameType, H2Error}, + serializer::{gen_frame_header, gen_rst_stream}, + StreamId, + }, }; pub struct H2BlockConverter<'a> { @@ -17,6 +21,26 @@ pub struct H2BlockConverter<'a> { } impl<'a, T: AsBuffer> BlockConverter for H2BlockConverter<'a> { + fn initialize(&mut self, kawa: &mut Kawa) { + // This is very ugly... we may add a h2 variant in kawa::ParsingErrorKind + match kawa.parsing_phase { + ParsingPhase::Error { + kind: ParsingErrorKind::Processing { message }, + .. + } => { + let error = str_to_error_code(message); + let mut frame = [0; 13]; + gen_rst_stream(&mut frame, self.stream_id, error).unwrap(); + kawa.push_out(Store::from_slice(&frame)); + } + ParsingPhase::Error { .. } => { + let mut frame = [0; 13]; + gen_rst_stream(&mut frame, self.stream_id, H2Error::InternalError).unwrap(); + kawa.push_out(Store::from_slice(&frame)); + } + _ => {} + } + } fn call(&mut self, block: Block, kawa: &mut Kawa) { let buffer = kawa.storage.buffer(); match block { @@ -140,24 +164,18 @@ impl<'a, T: AsBuffer> BlockConverter for H2BlockConverter<'a> { kawa.push_out(Store::from_slice(&header)); kawa.push_out(Store::Alloc(payload.into_boxed_slice(), 0)); } else if end_stream { - if kawa.is_error() { - let mut frame = [0; 13]; - gen_rst_stream(&mut frame, self.stream_id, H2Error::InternalError).unwrap(); - kawa.push_out(Store::from_slice(&frame)); - } else { - let mut header = [0; 9]; - gen_frame_header( - &mut header, - &FrameHeader { - payload_len: 0, - frame_type: FrameType::Data, - flags: 1, - stream_id: self.stream_id, - }, - ) - .unwrap(); - kawa.push_out(Store::from_slice(&header)); - } + let mut header = [0; 9]; + gen_frame_header( + &mut header, + &FrameHeader { + payload_len: 0, + frame_type: FrameType::Data, + flags: 1, + stream_id: self.stream_id, + }, + ) + .unwrap(); + kawa.push_out(Store::from_slice(&header)); } if end_header || end_stream { kawa.push_delimiter() diff --git a/lib/src/protocol/mux/h1.rs b/lib/src/protocol/mux/h1.rs index 3fb4bdff6..81d1637e9 100644 --- a/lib/src/protocol/mux/h1.rs +++ b/lib/src/protocol/mux/h1.rs @@ -3,9 +3,7 @@ use sozu_command::ready::Ready; use crate::{ println_, protocol::mux::{ - debug_kawa, forcefully_terminate_answer, set_default_answer, update_readiness_after_read, - update_readiness_after_write, BackendStatus, Context, Endpoint, GlobalStreamId, MuxResult, - Position, StreamState, + debug_kawa, forcefully_terminate_answer, parser::H2Error, set_default_answer, update_readiness_after_read, update_readiness_after_write, BackendStatus, Context, Endpoint, GlobalStreamId, MuxResult, Position, StreamState }, socket::SocketHandler, timer::TimeoutContainer, @@ -226,6 +224,9 @@ impl ConnectionH1 { Position::Client(BackendStatus::Connected(cluster_id)) | Position::Client(BackendStatus::Connecting(cluster_id)) => { self.stream = usize::MAX; + // keep alive should probably be used only if the http context is fully reset + // in case end_stream occurs due to an error the connection state is probably + // unrecoverable and should be terminated if stream_context.keep_alive_backend { self.position = Position::Client(BackendStatus::KeepAlive(std::mem::take(cluster_id))) @@ -241,7 +242,7 @@ impl ConnectionH1 { // if the answer is not terminated we send an RstStream to properly clean the stream // if it is terminated, we finish the transfer, the backend is not necessary anymore if !stream.back.is_terminated() { - forcefully_terminate_answer(stream, &mut self.readiness); + forcefully_terminate_answer(stream, &mut self.readiness, H2Error::InternalError); } else { stream.state = StreamState::Unlinked; self.readiness.interest.insert(Ready::WRITABLE); diff --git a/lib/src/protocol/mux/h2.rs b/lib/src/protocol/mux/h2.rs index 91647d9bb..5000b5294 100644 --- a/lib/src/protocol/mux/h2.rs +++ b/lib/src/protocol/mux/h2.rs @@ -9,7 +9,7 @@ use crate::{ converter, debug_kawa, forcefully_terminate_answer, parser::{ self, error_code_to_str, Frame, FrameHeader, FrameType, H2Error, Headers, ParserError, - ParserErrorKind, + ParserErrorKind, StreamDependency, WindowUpdate, }, pkawa, serializer, set_default_answer, update_readiness_after_read, update_readiness_after_write, BackendStatus, Context, Endpoint, GenericHttpStream, @@ -84,8 +84,25 @@ impl Prioriser { pub fn new() -> Self { Self {} } - pub fn push_priority(&mut self, stream_id: StreamId, priority: parser::PriorityPart) { - println!("PRIORITY REQUEST FOR {stream_id}: {priority:?}"); + pub fn push_priority(&mut self, stream_id: StreamId, priority: parser::PriorityPart) -> bool { + println_!("PRIORITY REQUEST FOR {stream_id}: {priority:?}"); + match priority { + parser::PriorityPart::Rfc7540 { + stream_dependency, + weight, + } => { + if stream_dependency.stream_id == stream_id { + println_!("STREAM CAN'T DEPEND ON ITSELF"); + true + } else { + false + } + } + parser::PriorityPart::Rfc9218 { + urgency, + incremental, + } => false, + } } } @@ -289,13 +306,18 @@ impl ConnectionH2 { let read_stream = if stream_id == 0 { H2StreamId::Zero } else if let Some(global_stream_id) = self.streams.get(&stream_id) { + let allowed_on_half_closed = header.frame_type + == FrameType::WindowUpdate + || header.frame_type == FrameType::Priority; let stream = &context.streams[*global_stream_id]; println_!( "REQUESTING EXISTING STREAM {stream_id}: {}/{:?}", stream.received_end_of_stream, stream.state ); - if stream.received_end_of_stream || !stream.state.is_open() { + if !allowed_on_half_closed + && (stream.received_end_of_stream || !stream.state.is_open()) + { return self.goaway(H2Error::StreamClosed); } if header.frame_type == FrameType::Data { @@ -318,7 +340,9 @@ impl ConnectionH2 { Some(_) => {} None => return self.goaway(H2Error::InternalError), } - } else if header.frame_type != FrameType::Priority { + } else if header.frame_type != FrameType::Priority + && header.frame_type != FrameType::WindowUpdate + { println_!( "ONLY HEADERS AND PRIORITY CAN BE RECEIVED ON IDLE/CLOSED STREAMS" ); @@ -555,7 +579,7 @@ impl ConnectionH2 { Position::Client(_) => {} Position::Server => { // mark stream as reusable - println_!("Recycle stream: {global_stream_id}"); + println_!("Recycle1 stream: {global_stream_id}"); // ACCESS LOG stream.generate_access_log( false, @@ -683,12 +707,25 @@ impl ConnectionH2 { // can this fail? let stream_id = headers.stream_id; let global_stream_id = *self.streams.get(&stream_id).unwrap(); + + if let Some(priority) = &headers.priority { + if self.prioriser.push_priority(stream_id, priority.clone()) { + self.reset_stream( + global_stream_id, + context, + endpoint, + H2Error::ProtocolError, + ); + return MuxResult::Continue; + } + } + let kawa = &mut self.zero; let buffer = headers.header_block_fragment.data(kawa.storage.buffer()); let stream = &mut context.streams[global_stream_id]; let parts = &mut stream.split(&self.position); let was_initial = parts.rbuffer.is_initial(); - pkawa::handle_header( + let status = pkawa::handle_header( &mut self.decoder, &mut self.prioriser, stream_id, @@ -698,8 +735,12 @@ impl ConnectionH2 { parts.context, ); kawa.storage.clear(); - if parts.rbuffer.is_error() { - return self.goaway(H2Error::CompressionError); + if let Err((error, global)) = status { + if global { + return self.goaway(error); + } else { + return self.reset_stream(global_stream_id, context, endpoint, error); + } } debug_kawa(parts.rbuffer); stream.received_end_of_stream |= headers.end_stream; @@ -729,9 +770,23 @@ impl ConnectionH2 { return self.goaway(H2Error::ProtocolError); } }, - Frame::Priority(priority) => self - .prioriser - .push_priority(priority.stream_id, priority.inner), + Frame::Priority(priority) => { + if self + .prioriser + .push_priority(priority.stream_id, priority.inner) + { + if let Some(global_stream_id) = self.streams.get(&priority.stream_id) { + return self.reset_stream( + *global_stream_id, + context, + endpoint, + H2Error::ProtocolError, + ); + } else { + return self.goaway(H2Error::ProtocolError); + } + } + } Frame::RstStream(rst_stream) => { println_!( "RstStream({} -> {})", @@ -766,18 +821,23 @@ impl ConnectionH2 { return MuxResult::Continue; } for setting in settings.settings { + let v = setting.value; + let mut is_error = false; #[rustfmt::skip] let _ = match setting.identifier { - 1 => self.peer_settings.settings_header_table_size = setting.value, - 2 => self.peer_settings.settings_enable_push = setting.value == 1, - 3 => self.peer_settings.settings_max_concurrent_streams = setting.value, - 4 => self.peer_settings.settings_initial_window_size = setting.value, - 5 => self.peer_settings.settings_max_frame_size = setting.value, - 6 => self.peer_settings.settings_max_header_list_size = setting.value, - 8 => self.peer_settings.settings_enable_connect_protocol = setting.value == 1, - 9 => self.peer_settings.settings_no_rfc7540_priorities = setting.value == 1, + 1 => { self.peer_settings.settings_header_table_size = v }, + 2 => { self.peer_settings.settings_enable_push = v == 1; is_error |= v > 1 }, + 3 => { self.peer_settings.settings_max_concurrent_streams = v }, + 4 => { self.peer_settings.settings_initial_window_size = v; is_error |= v >= 1<<31 }, + 5 => { self.peer_settings.settings_max_frame_size = v; is_error |= v >= 1<<24 || v < 1<<14 }, + 6 => { self.peer_settings.settings_max_header_list_size = v }, + 8 => { self.peer_settings.settings_enable_connect_protocol = v == 1; is_error |= v > 1 }, + 9 => { self.peer_settings.settings_no_rfc7540_priorities = v == 1; is_error |= v > 1 }, other => println!("unknown setting_id: {other}, we MUST ignore this"), }; + if is_error { + return self.goaway(H2Error::ProtocolError); + } } println_!("{:#?}", self.peer_settings); @@ -792,6 +852,9 @@ impl ConnectionH2 { self.expect_write = Some(H2StreamId::Zero); } Frame::Ping(ping) => { + if ping.ack { + return MuxResult::Continue; + } let kawa = &mut self.zero; match serializer::gen_ping_acknolegment(kawa.storage.space(), &ping.payload) { Ok((_, size)) => kawa.storage.fill(size), @@ -812,21 +875,36 @@ impl ConnectionH2 { ); // return self.goaway(H2Error::NoError); } - Frame::WindowUpdate(update) => { - let window = if update.stream_id == 0 { - &mut self.window + Frame::WindowUpdate(WindowUpdate { + stream_id, + increment, + }) => { + let increment = increment as i32; + if stream_id == 0 { + if increment > i32::MAX - self.window { + return self.goaway(H2Error::FlowControlError); + } else { + self.window += increment; + } } else { - if let Some(global_stream_id) = self.streams.get(&update.stream_id) { - &mut context.streams[*global_stream_id].window + if let Some(global_stream_id) = self.streams.get(&stream_id) { + let stream = &mut context.streams[*global_stream_id]; + if increment > i32::MAX - stream.window { + return self.reset_stream( + *global_stream_id, + context, + endpoint, + H2Error::FlowControlError, + ); + } else { + stream.window += increment; + } } else { - unreachable!() + println_!( + "Ignoring window update on closed stream {stream_id}: {increment}" + ); } }; - if update.increment as i32 > i32::MAX - *window { - return self.goaway(H2Error::FlowControlError); - } else { - *window += update.increment as i32; - } } Frame::Continuation(_) => unreachable!(), } @@ -867,6 +945,27 @@ impl ConnectionH2 { } } + pub fn reset_stream( + &mut self, + stream_id: GlobalStreamId, + context: &mut Context, + mut endpoint: E, + error: H2Error, + ) -> MuxResult + where + E: Endpoint, + L: ListenerHandler + L7ListenerHandler, + { + let stream = &mut context.streams[stream_id]; + println_!("reset H2 stream {stream_id}: {:#?}", stream.context); + let old_state = std::mem::replace(&mut stream.state, StreamState::Unlinked); + forcefully_terminate_answer(stream, &mut self.readiness, error); + if let StreamState::Linked(token) = old_state { + endpoint.end_stream(token, stream_id, context); + } + MuxResult::Continue + } + pub fn end_stream(&mut self, stream: GlobalStreamId, context: &mut Context) where L: ListenerHandler + L7ListenerHandler, @@ -878,6 +977,9 @@ impl ConnectionH2 { for (stream_id, global_stream_id) in &self.streams { if *global_stream_id == stream { let id = *stream_id; + // if the stream is not in a closed state we should probably send an + // RST_STREAM frame here. We also need to handle frames coming from + // the backend on this stream after it was closed self.streams.remove(&id); return; } @@ -893,7 +995,11 @@ impl ConnectionH2 { // if the answer is not terminated we send an RstStream to properly clean the stream // if it is terminated, we finish the transfer, the backend is not necessary anymore if !stream.back.is_terminated() { - forcefully_terminate_answer(stream, &mut self.readiness); + forcefully_terminate_answer( + stream, + &mut self.readiness, + H2Error::InternalError, + ); } else { stream.state = StreamState::Unlinked; self.readiness.interest.insert(Ready::WRITABLE); diff --git a/lib/src/protocol/mux/mod.rs b/lib/src/protocol/mux/mod.rs index 4fdde0636..8acbd4f20 100644 --- a/lib/src/protocol/mux/mod.rs +++ b/lib/src/protocol/mux/mod.rs @@ -37,7 +37,11 @@ use crate::{ StateResult, }; -pub use crate::protocol::mux::{h1::ConnectionH1, h2::ConnectionH2}; +pub use crate::protocol::mux::{ + h1::ConnectionH1, + h2::ConnectionH2, + parser::{error_code_to_str, H2Error}, +}; #[macro_export] macro_rules! println_ { @@ -127,15 +131,18 @@ fn set_default_answer(stream: &mut Stream, readiness: &mut Readiness, code: u16) /// Forcefully terminates a kawa message by setting the "end_stream" flag and setting the parsing_phase to Error. /// An H2 converter will produce an RstStream frame. -fn forcefully_terminate_answer(stream: &mut Stream, readiness: &mut Readiness) { +fn forcefully_terminate_answer(stream: &mut Stream, readiness: &mut Readiness, error: H2Error) { let kawa = &mut stream.back; - kawa.push_block(kawa::Block::Flags(kawa::Flags { - end_body: false, - end_chunk: false, - end_header: false, - end_stream: true, - })); - kawa.parsing_phase.error("Termination".into()); + kawa.out.clear(); + kawa.blocks.clear(); + // kawa.push_block(kawa::Block::Flags(kawa::Flags { + // end_body: false, + // end_chunk: false, + // end_header: false, + // end_stream: true, + // })); + kawa.parsing_phase + .error(error_code_to_str(error as u32).into()); debug_kawa(kawa); stream.state = StreamState::Unlinked; readiness.interest.insert(Ready::WRITABLE); @@ -1310,7 +1317,11 @@ impl &'static str { } } +pub fn str_to_error_code(str: &str) -> H2Error { + match str { + "NO_ERROR" => H2Error::NoError, + "PROTOCOL_ERROR" => H2Error::ProtocolError, + "INTERNAL_ERROR" => H2Error::InternalError, + "FLOW_CONTROL_ERROR" => H2Error::FlowControlError, + "SETTINGS_TIMEOUT" => H2Error::SettingsTimeout, + "STREAM_CLOSED" => H2Error::StreamClosed, + "FRAME_SIZE_ERROR" => H2Error::FrameSizeError, + "REFUSED_STREAM" => H2Error::RefusedStream, + "CANCEL" => H2Error::Cancel, + "COMPRESSION_ERROR" => H2Error::CompressionError, + "CONNECT_ERROR" => H2Error::ConnectError, + "ENHANCE_YOUR_CALM" => H2Error::EnhanceYourCalm, + "INADEQUATE_SECURITY" => H2Error::InadequateSecurity, + "HTTP_1_1_REQUIRED" => H2Error::HTTP11Required, + _ => H2Error::InternalError, + } +} + #[derive(Clone, Debug, PartialEq)] pub struct ParserError<'a> { pub input: &'a [u8], @@ -566,15 +586,19 @@ pub fn push_promise_frame<'a>( #[derive(Clone, Debug, PartialEq)] pub struct Ping { pub payload: [u8; 8], + pub ack: bool, } pub fn ping_frame<'a>( input: &'a [u8], - _header: &FrameHeader, + header: &FrameHeader, ) -> IResult<&'a [u8], Frame, ParserError<'a>> { let (i, data) = take(8usize)(input)?; - let mut p = Ping { payload: [0; 8] }; + let mut p = Ping { + payload: [0; 8], + ack: header.flags & 1 != 0, + }; p.payload[..8].copy_from_slice(&data[..8]); Ok((i, Frame::Ping(p))) diff --git a/lib/src/protocol/mux/pkawa.rs b/lib/src/protocol/mux/pkawa.rs index 5846fa36f..012c16ddb 100644 --- a/lib/src/protocol/mux/pkawa.rs +++ b/lib/src/protocol/mux/pkawa.rs @@ -1,4 +1,4 @@ -use std::{io::Write, str::from_utf8_unchecked}; +use std::{io::Write, str::from_utf8}; use kawa::{ h1::ParserCallbacks, repr::Slice, Block, BodySize, Flags, Kind, Pair, ParsingPhase, StatusLine, @@ -9,10 +9,28 @@ use crate::{ pool::Checkout, protocol::{ http::parser::compare_no_case, - mux::{h2::Prioriser, parser::PriorityPart, GenericHttpStream, StreamId}, + mux::{ + h2::Prioriser, + parser::{H2Error, PriorityPart}, + GenericHttpStream, StreamId, + }, }, }; +trait AdHocStore { + fn len(&self) -> usize; +} +impl AdHocStore for Store { + fn len(&self) -> usize { + match self { + Store::Empty => 0, + Store::Slice(slice) | Store::Detached(slice) => slice.len(), + Store::Static(s) => s.len(), + Store::Alloc(a, i) => a.len() - *i as usize, + } + } +} + pub fn handle_header( decoder: &mut hpack::Decoder, prioriser: &mut Prioriser, @@ -21,7 +39,8 @@ pub fn handle_header( input: &[u8], end_stream: bool, callbacks: &mut C, -) where +) -> Result<(), (H2Error, bool)> +where C: ParserCallbacks, { if !kawa.is_initial() { @@ -34,6 +53,8 @@ pub fn handle_header( let mut authority = Store::Empty; let mut path = Store::Empty; let mut scheme = Store::Empty; + let mut invalid_headers = false; + let mut regular_headers = false; let decode_status = decoder.decode_with_cb(input, |k, v| { let start = kawa.storage.end as u32; kawa.storage.write_all(&v).unwrap(); @@ -45,28 +66,49 @@ pub fn handle_header( }); if compare_no_case(&k, b":method") { + if !method.is_empty() || regular_headers { + invalid_headers = true; + } method = val; } else if compare_no_case(&k, b":scheme") { + if !scheme.is_empty() || regular_headers { + invalid_headers = true; + } scheme = val; } else if compare_no_case(&k, b":path") { + if !path.is_empty() || regular_headers { + invalid_headers = true; + } path = val; } else if compare_no_case(&k, b":authority") { + if !authority.is_empty() || regular_headers { + invalid_headers = true; + } authority = val; + } else if k.starts_with(b":") { + invalid_headers = true; } else if compare_no_case(&k, b"cookie---") { + regular_headers = true; todo!("cookies should be split in pairs"); - } else if compare_no_case(&k, b"priority") { - todo!("decode priority"); - prioriser.push_priority( - stream_id, - PriorityPart::Rfc9218 { - urgency: todo!(), - incremental: todo!(), - }, - ) } else { + regular_headers = true; if compare_no_case(&k, b"content-length") { - let length = unsafe { from_utf8_unchecked(&v).parse::().unwrap() }; - kawa.body_size = BodySize::Length(length); + if let Some(length) = + from_utf8(&v).ok().and_then(|v| v.parse::().ok()) + { + kawa.body_size = BodySize::Length(length); + } else { + invalid_headers = true; + } + } else if compare_no_case(&k, b"priority") { + todo!("decode priority"); + prioriser.push_priority( + stream_id, + PriorityPart::Rfc9218 { + urgency: todo!(), + incremental: todo!(), + }, + ); } kawa.storage.write_all(&k).unwrap(); let key = Store::Slice(Slice { @@ -78,11 +120,16 @@ pub fn handle_header( }); if let Err(error) = decode_status { println!("INVALID FRAGMENT: {error:?}"); - kawa.parsing_phase.error("Invalid header fragment".into()); + return Err((H2Error::CompressionError, true)); } - if method.is_empty() || authority.is_empty() || path.is_empty() { - println!("MISSING PSEUDO HEADERS"); - kawa.parsing_phase.error("Missing pseudo headers".into()); + if invalid_headers + || method.len() == 0 + || authority.len() == 0 + || path.len() == 0 + || scheme.len() == 0 + { + println!("INVALID HEADERS"); + return Err((H2Error::ProtocolError, false)); } // uri is only used by H1 statusline, in most cases it only consists of the path // a better algorithm should be used though @@ -107,6 +154,8 @@ pub fn handle_header( Kind::Response => { let mut code = 0; let mut status = Store::Empty; + let mut invalid_headers = false; + let mut regular_headers = false; let decode_status = decoder.decode_with_cb(input, |k, v| { let start = kawa.storage.end as u32; kawa.storage.write_all(&v).unwrap(); @@ -118,9 +167,21 @@ pub fn handle_header( }); if compare_no_case(&k, b":status") { + if !status.is_empty() || regular_headers { + invalid_headers = true; + } status = val; - code = unsafe { from_utf8_unchecked(&v).parse::().ok().unwrap() } + if let Some(parsed_code) = + from_utf8(&v).ok().and_then(|v| v.parse::().ok()) + { + code = parsed_code; + } else { + invalid_headers = true; + } + } else if k.starts_with(b":") { + invalid_headers = true; } else { + regular_headers = true; kawa.storage.write_all(&k).unwrap(); let key = Store::Slice(Slice { start: start + len_val, @@ -131,11 +192,11 @@ pub fn handle_header( }); if let Err(error) = decode_status { println!("INVALID FRAGMENT: {error:?}"); - kawa.parsing_phase.error("Invalid header fragment".into()); + return Err((H2Error::CompressionError, true)); } - if status.is_empty() { - println!("MISSING PSEUDO HEADERS"); - kawa.parsing_phase.error("Missing pseudo headers".into()); + if invalid_headers || status.len() == 0 { + println!("INVALID HEADERS"); + return Err((H2Error::ProtocolError, false)); } StatusLine::Response { version: Version::V20, @@ -153,9 +214,6 @@ pub fn handle_header( kawa.storage.start, kawa.storage.head, kawa.storage.end ); - if kawa.is_error() { - return; - } callbacks.on_headers(kawa); if end_stream { @@ -176,7 +234,7 @@ pub fn handle_header( })); if kawa.parsing_phase == ParsingPhase::Terminated { - return; + return Ok(()); } kawa.parsing_phase = match kawa.body_size { @@ -185,6 +243,7 @@ pub fn handle_header( BodySize::Length(_) => ParsingPhase::Body, BodySize::Empty => ParsingPhase::Chunks { first: true }, }; + Ok(()) } pub fn handle_trailer( @@ -192,32 +251,38 @@ pub fn handle_trailer( input: &[u8], end_stream: bool, decoder: &mut hpack::Decoder, -) { - decoder - .decode_with_cb(input, |k, v| { - let start = kawa.storage.end as u32; - kawa.storage.write_all(&k).unwrap(); - kawa.storage.write_all(&v).unwrap(); - let len_key = k.len() as u32; - let len_val = v.len() as u32; - let key = Store::Slice(Slice { - start, - len: len_key, - }); - let val = Store::Slice(Slice { - start: start + len_key, - len: len_val, - }); - kawa.push_block(Block::Header(Pair { key, val })); - }) - .unwrap(); +) -> Result<(), (H2Error, bool)> { + if !end_stream { + return Err((H2Error::ProtocolError, false)); + } + let decode_status = decoder.decode_with_cb(input, |k, v| { + let start = kawa.storage.end as u32; + kawa.storage.write_all(&k).unwrap(); + kawa.storage.write_all(&v).unwrap(); + let len_key = k.len() as u32; + let len_val = v.len() as u32; + let key = Store::Slice(Slice { + start, + len: len_key, + }); + let val = Store::Slice(Slice { + start: start + len_key, + len: len_val, + }); + kawa.push_block(Block::Header(Pair { key, val })); + }); + + if let Err(error) = decode_status { + println!("INVALID FRAGMENT: {error:?}"); + return Err((H2Error::CompressionError, true)); + } - // assert!(end_stream); kawa.push_block(Block::Flags(Flags { - end_body: end_stream, + end_body: false, end_chunk: false, end_header: true, - end_stream, + end_stream: true, })); kawa.parsing_phase = ParsingPhase::Terminated; + Ok(()) } diff --git a/lib/src/protocol/mux/serializer.rs b/lib/src/protocol/mux/serializer.rs index e7eb39127..18c53720d 100644 --- a/lib/src/protocol/mux/serializer.rs +++ b/lib/src/protocol/mux/serializer.rs @@ -6,7 +6,7 @@ use cookie_factory::{ GenError, }; -use super::{ +use crate::protocol::mux::{ h2::H2Settings, parser::{FrameHeader, FrameType, H2Error}, };