diff --git a/lib/src/protocol/mux/h1.rs b/lib/src/protocol/mux/h1.rs index d6d05f53d..24bf2abb2 100644 --- a/lib/src/protocol/mux/h1.rs +++ b/lib/src/protocol/mux/h1.rs @@ -1,7 +1,7 @@ use sozu_command::ready::Ready; use crate::{ - protocol::mux::{Context, GlobalStreamId, Position}, + protocol::mux::{Context, GlobalStreamId, MuxResult, Position}, socket::{SocketHandler, SocketResult}, Readiness, }; @@ -15,7 +15,7 @@ pub struct ConnectionH1 { } impl ConnectionH1 { - pub fn readable(&mut self, context: &mut Context) { + pub fn readable(&mut self, context: &mut Context) -> MuxResult { println!("======= MUX H1 READABLE"); let stream = &mut context.streams.get(self.stream); let kawa = match self.position { @@ -40,8 +40,9 @@ impl ConnectionH1 { if kawa.is_terminated() { self.readiness.interest.remove(Ready::READABLE); } + MuxResult::Continue } - pub fn writable(&mut self, context: &mut Context) { + pub fn writable(&mut self, context: &mut Context) -> MuxResult { println!("======= MUX H1 WRITABLE"); let stream = &mut context.streams.get(self.stream); let kawa = match self.position { @@ -52,7 +53,7 @@ impl ConnectionH1 { let bufs = kawa.as_io_slice(); if bufs.is_empty() { self.readiness.interest.remove(Ready::WRITABLE); - return; + return MuxResult::Continue; } let (size, status) = self.socket.socket_write_vectored(&bufs); println!(" size: {size}, status: {status:?}"); @@ -62,5 +63,6 @@ impl ConnectionH1 { } else { self.readiness.event.remove(Ready::WRITABLE); } + MuxResult::Continue } } diff --git a/lib/src/protocol/mux/h2.rs b/lib/src/protocol/mux/h2.rs index 4322cc336..6a107c016 100644 --- a/lib/src/protocol/mux/h2.rs +++ b/lib/src/protocol/mux/h2.rs @@ -6,8 +6,8 @@ use sozu_command::ready::Ready; use crate::{ protocol::mux::{ - parser::{self, error_code_to_str, FrameHeader}, - pkawa, serializer, Context, GlobalStreamId, Position, StreamId, + parser::{self, error_code_to_str, Frame, FrameHeader, FrameType}, + pkawa, serializer, Context, GlobalStreamId, MuxResult, Position, StreamId, }, socket::SocketHandler, Readiness, @@ -58,7 +58,7 @@ pub struct ConnectionH2 { } impl ConnectionH2 { - pub fn readable(&mut self, context: &mut Context) { + pub fn readable(&mut self, context: &mut Context) -> MuxResult { println!("======= MUX H2 READABLE"); let kawa = if let Some((stream_id, amount)) = self.expect { let kawa = context.streams.get(stream_id).front(self.position); @@ -70,16 +70,16 @@ impl ConnectionH2 { self.expect = None; } else { self.expect = Some((stream_id, amount - size)); - return; + return MuxResult::Continue; } } else { self.readiness.event.remove(Ready::READABLE); - return; + return MuxResult::Continue; } kawa } else { self.readiness.event.remove(Ready::READABLE); - return; + return MuxResult::Continue; }; match (&self.state, &self.position) { (H2State::ClientPreface, Position::Client) => { @@ -94,9 +94,9 @@ impl ConnectionH2 { match parser::frame_header(i) { Ok(( _, - parser::FrameHeader { + FrameHeader { payload_len, - frame_type: parser::FrameType::Settings, + frame_type: FrameType::Settings, flags: 0, stream_id: 0, }, @@ -121,9 +121,9 @@ impl ConnectionH2 { self.state = H2State::ServerSettings; match serializer::gen_frame_header( kawa.storage.space(), - &parser::FrameHeader { + &FrameHeader { payload_len: 6 * 2, - frame_type: parser::FrameType::Settings, + frame_type: FrameType::Settings, flags: 0, stream_id: 0, }, @@ -136,9 +136,9 @@ impl ConnectionH2 { // .unwrap(); match serializer::gen_frame_header( kawa.storage.space(), - &parser::FrameHeader { + &FrameHeader { payload_len: 0, - frame_type: parser::FrameType::Settings, + frame_type: FrameType::Settings, flags: 1, stream_id: 0, }, @@ -166,7 +166,7 @@ impl ConnectionH2 { } else { self.create_stream(header.stream_id, context) }; - let stream_id = if header.frame_type == parser::FrameType::Data { + let stream_id = if header.frame_type == FrameType::Data { stream_id } else { 0 @@ -181,21 +181,23 @@ impl ConnectionH2 { (H2State::Frame(header), Position::Server) => { let i = kawa.storage.data(); println!(" data: {i:?}"); - match parser::frame_body(i, header, self.settings.settings_max_frame_size) { - Ok((_, frame)) => { - kawa.storage.clear(); - self.handle(frame, context); - } - Err(e) => panic!("{e:?}"), - } + let frame = + match parser::frame_body(i, header, self.settings.settings_max_frame_size) { + Ok((_, frame)) => frame, + Err(e) => panic!("{e:?}"), + }; + kawa.storage.clear(); + let state_result = self.handle(frame, context); self.state = H2State::Header; self.expect = Some((0, 9)); + return state_result; } _ => unreachable!(), } + MuxResult::Continue } - pub fn writable(&mut self, context: &mut Context) { + pub fn writable(&mut self, context: &mut Context) -> MuxResult { println!("======= MUX H2 WRITABLE"); match (&self.state, &self.position) { (H2State::ClientPreface, Position::Client) => todo!("Send PRI + client Settings"), @@ -214,6 +216,7 @@ impl ConnectionH2 { self.state = H2State::Header; self.expect = Some((0, 9)); } + MuxResult::Continue } _ => unreachable!(), } @@ -229,25 +232,43 @@ impl ConnectionH2 { } } - fn handle(&mut self, frame: parser::Frame, context: &mut Context) { + fn handle(&mut self, frame: Frame, context: &mut Context) -> MuxResult { println!("{frame:?}"); match frame { - parser::Frame::Data(_) => todo!(), - parser::Frame::Headers(headers) => { - // if !headers.end_headers { - // self.state = H2State::Continuation - // } - let global_stream_id = self.streams.get(&headers.stream_id).unwrap(); + Frame::Data(_) => todo!(), + Frame::Headers(headers) => { + if !headers.end_headers { + todo!(); + // self.state = H2State::Continuation + } + let global_stream_id = *self.streams.get(&headers.stream_id).unwrap(); let kawa = context.streams.zero.front(self.position); let buffer = headers.header_block_fragment.data(kawa.storage.buffer()); - let stream = &mut context.streams.others[*global_stream_id - 1]; + let stream = &mut context.streams.others[global_stream_id - 1]; let kawa = &mut stream.front; pkawa::handle_header(kawa, buffer, &mut context.decoder); stream.context.on_headers(kawa); + return MuxResult::Connect(global_stream_id); } - parser::Frame::Priority(priority) => (), - parser::Frame::RstStream(_) => todo!(), - parser::Frame::Settings(settings) => { + Frame::PushPromise(push_promise) => match self.position { + Position::Client => { + todo!("if enabled forward the push") + } + Position::Server => { + println!("A client should not push promises"); + return MuxResult::CloseSession; + } + }, + Frame::Priority(priority) => (), + Frame::RstStream(rst_stream) => { + println!( + "RstStream({} -> {})", + rst_stream.error_code, + error_code_to_str(rst_stream.error_code) + ); + // context.streams.get(priority.stream_id).close() + } + Frame::Settings(settings) => { for setting in settings.settings { match setting.identifier { 1 => self.settings.settings_header_table_size = setting.value, @@ -261,14 +282,21 @@ impl ConnectionH2 { } println!("{:#?}", self.settings); } - parser::Frame::PushPromise(_) => todo!(), - parser::Frame::Ping(_) => todo!(), - parser::Frame::GoAway(goaway) => panic!("{}", error_code_to_str(goaway.error_code)), - parser::Frame::WindowUpdate(update) => { + Frame::Ping(_) => todo!(), + Frame::GoAway(goaway) => { + println!( + "GoAway({} -> {})", + goaway.error_code, + error_code_to_str(goaway.error_code) + ); + return MuxResult::CloseSession; + } + Frame::WindowUpdate(update) => { let global_stream_id = *self.streams.get(&update.stream_id).unwrap(); context.streams.get(global_stream_id).window += update.increment as i32; } - parser::Frame::Continuation(_) => todo!(), + Frame::Continuation(_) => todo!(), } + MuxResult::Continue } } diff --git a/lib/src/protocol/mux/mod.rs b/lib/src/protocol/mux/mod.rs index 6b4f63fe5..906445459 100644 --- a/lib/src/protocol/mux/mod.rs +++ b/lib/src/protocol/mux/mod.rs @@ -28,7 +28,7 @@ use crate::{ AcceptError, L7Proxy, ProxySession, Readiness, SessionMetrics, SessionResult, StateResult, }; -use self::h2::{H2State, H2Settings}; +use self::h2::{H2Settings, H2State}; /// Generic Http representation using the Kawa crate using the Checkout of Sozu as buffer type GenericHttpStream = kawa::Kawa; @@ -41,6 +41,13 @@ pub enum Position { Server, } +pub enum MuxResult { + Continue, + CloseSession, + Close(GlobalStreamId), + Connect(GlobalStreamId), +} + pub enum Connection { H1(ConnectionH1), H2(ConnectionH2), @@ -111,13 +118,13 @@ impl Connection { Connection::H2(c) => &mut c.readiness, } } - fn readable(&mut self, context: &mut Context) { + fn readable(&mut self, context: &mut Context) -> MuxResult { match self { Connection::H1(c) => c.readable(context), Connection::H2(c) => c.readable(context), } } - fn writable(&mut self, context: &mut Context) { + fn writable(&mut self, context: &mut Context) -> MuxResult { match self { Connection::H1(c) => c.writable(context), Connection::H2(c) => c.writable(context), @@ -128,8 +135,8 @@ impl Connection { pub struct Stream { // pub request_id: Ulid, pub window: i32, - pub front: GenericHttpStream, - pub back: GenericHttpStream, + front: GenericHttpStream, + back: GenericHttpStream, pub context: HttpContext, } @@ -276,24 +283,44 @@ impl SessionState for Mux { let mut dirty = false; if self.frontend.readiness().filter_interest().is_readable() { - self.frontend.readable(context); + match self.frontend.readable(context) { + MuxResult::Continue => (), + MuxResult::CloseSession => return SessionResult::Close, + MuxResult::Close(_) => todo!(), + MuxResult::Connect(_) => todo!(), + } dirty = true; } for (_, backend) in self.backends.iter_mut() { if backend.readiness().filter_interest().is_writable() { - backend.writable(context); + match backend.writable(context) { + MuxResult::Continue => (), + MuxResult::CloseSession => return SessionResult::Close, + MuxResult::Close(_) => todo!(), + MuxResult::Connect(_) => unreachable!(), + } dirty = true; } if backend.readiness().filter_interest().is_readable() { - backend.readable(context); + match backend.readable(context) { + MuxResult::Continue => (), + MuxResult::CloseSession => return SessionResult::Close, + MuxResult::Close(_) => todo!(), + MuxResult::Connect(_) => unreachable!(), + } dirty = true; } } if self.frontend.readiness().filter_interest().is_writable() { - self.frontend.writable(context); + match self.frontend.writable(context) { + MuxResult::Continue => (), + MuxResult::CloseSession => return SessionResult::Close, + MuxResult::Close(_) => todo!(), + MuxResult::Connect(_) => unreachable!(), + } dirty = true; }