diff --git a/lib/src/https.rs b/lib/src/https.rs index 7195af71d..3eaf5c197 100644 --- a/lib/src/https.rs +++ b/lib/src/https.rs @@ -54,6 +54,7 @@ use crate::{ answers::HttpAnswers, parser::{hostname_and_port, Method}, }, + mux::Mux, proxy_protocol::expect::ExpectProxyProtocol, rustls::TlsHandshake, Http, Pipe, SessionState, @@ -65,8 +66,8 @@ use crate::{ tls::{CertificateResolver, MutexWrappedCertificateResolver, ParsedCertificateAndKey}, util::UnwrapLog, AcceptError, CachedTags, FrontendFromRequestError, L7ListenerHandler, L7Proxy, ListenerError, - ListenerHandler, Protocol, ProxyConfiguration, ProxyError, ProxySession, SessionIsToBeClosed, - SessionMetrics, SessionResult, StateMachineBuilder, StateResult, + ListenerHandler, Protocol, ProxyConfiguration, ProxyError, ProxySession, Readiness, + SessionIsToBeClosed, SessionMetrics, SessionResult, StateMachineBuilder, StateResult, }; #[derive(Debug, Clone, PartialEq, Eq)] @@ -86,6 +87,7 @@ StateMachineBuilder! { enum HttpsStateMachine impl SessionState { Expect(ExpectProxyProtocol, ServerConnection), Handshake(TlsHandshake), + Mux(Mux), Http(Http), WebSocket(Pipe), Http2(Http2) -> todo!("H2"), @@ -183,6 +185,7 @@ impl HttpsSession { HttpsStateMachine::Expect(expect, ssl) => self.upgrade_expect(expect, ssl), HttpsStateMachine::Handshake(handshake) => self.upgrade_handshake(handshake), HttpsStateMachine::Http(http) => self.upgrade_http(http), + HttpsStateMachine::Mux(mux) => unimplemented!(), HttpsStateMachine::Http2(_) => self.upgrade_http2(), HttpsStateMachine::WebSocket(wss) => self.upgrade_websocket(wss), HttpsStateMachine::FailedUpgrade(_) => unreachable!(), @@ -266,6 +269,7 @@ impl HttpsSession { // Some client don't fill in the ALPN protocol, in this case we default to Http/1.1 None => AlpnProtocol::Http11, }; + println!("ALPN: {alpn:?}"); if let Some(version) = handshake.session.protocol_version() { incr!(rustls_version_str(version)); @@ -280,6 +284,50 @@ impl HttpsSession { }; gauge_add!("protocol.tls.handshake", -1); + // return Some(HttpsStateMachine::Mux(Mux::new( + // self.frontend_token, + // handshake.request_id, + // self.listener.clone(), + // self.pool.clone(), + // self.public_address, + // self.peer_address, + // self.sticky_name.clone(), + // ))); + use crate::protocol::mux; + let frontend = match alpn { + AlpnProtocol::Http11 => mux::Connection::H1(mux::ConnectionH1 { + socket: front_stream, + position: mux::Position::Server, + readiness: Readiness { + interest: Ready::READABLE | Ready::HUP | Ready::ERROR, + event: handshake.frontend_readiness.event, + }, + stream: 0, + }), + AlpnProtocol::H2 => mux::Connection::H2(mux::ConnectionH2 { + socket: front_stream, + position: mux::Position::Server, + readiness: Readiness { + interest: Ready::READABLE | Ready::HUP | Ready::ERROR, + event: handshake.frontend_readiness.event, + }, + streams: HashMap::new(), + state: mux::H2State::ClientPreface, + }), + }; + let mut mux = Mux { + frontend_token: self.frontend_token, + frontend, + backends: HashMap::new(), + streams: Vec::new(), + listener: self.listener.clone(), + pool: self.pool.clone(), + public_address: self.public_address, + peer_address: self.peer_address, + sticky_name: self.sticky_name.clone(), + }; + mux.create_stream(handshake.request_id).ok()?; + return Some(HttpsStateMachine::Mux(mux)); match alpn { AlpnProtocol::Http11 => { let mut http = Http::new( @@ -408,6 +456,7 @@ impl ProxySession for HttpsSession { StateMarker::Expect => gauge_add!("protocol.proxy.expect", -1), StateMarker::Handshake => gauge_add!("protocol.tls.handshake", -1), StateMarker::Http => gauge_add!("protocol.https", -1), + StateMarker::Mux => gauge_add!("protocol.https", -1), StateMarker::WebSocket => { gauge_add!("protocol.wss", -1); gauge_add!("websocket.active_requests", -1); @@ -469,6 +518,7 @@ impl ProxySession for HttpsSession { } fn ready(&mut self, session: Rc>) -> SessionIsToBeClosed { + println!("READY"); self.metrics.service_start(); let session_result = diff --git a/lib/src/protocol/mod.rs b/lib/src/protocol/mod.rs index 1342173a8..88dde18cb 100644 --- a/lib/src/protocol/mod.rs +++ b/lib/src/protocol/mod.rs @@ -1,5 +1,6 @@ pub mod h2; pub mod kawa_h1; +pub mod mux; pub mod pipe; pub mod proxy_protocol; pub mod rustls; diff --git a/lib/src/protocol/mux/mod.rs b/lib/src/protocol/mux/mod.rs new file mode 100644 index 000000000..16476349e --- /dev/null +++ b/lib/src/protocol/mux/mod.rs @@ -0,0 +1,446 @@ +use std::{ + cell::RefCell, + collections::HashMap, + io::Write, + net::SocketAddr, + rc::{Rc, Weak}, +}; + +use mio::{net::TcpStream, Token}; +use rusty_ulid::Ulid; +use sozu_command::ready::Ready; + +mod parser; +mod serializer; + +use crate::{ + https::HttpsListener, + pool::{Checkout, Pool}, + protocol::SessionState, + socket::{FrontRustls, SocketHandler, SocketResult}, + AcceptError, L7Proxy, ProxySession, Readiness, SessionMetrics, SessionResult, StateResult, +}; + +/// Generic Http representation using the Kawa crate using the Checkout of Sozu as buffer +type GenericHttpStream = kawa::Kawa; +type StreamId = usize; +type GlobalStreamId = usize; + +pub enum Position { + Client, + Server, +} + +pub struct ConnectionH1 { + pub socket: Front, + pub position: Position, + pub readiness: Readiness, + pub stream: GlobalStreamId, +} + +pub enum H2State { + ClientPreface, + ServerPreface, + Connected, + Error, +} +pub struct ConnectionH2 { + pub socket: Front, + pub position: Position, + pub readiness: Readiness, + pub state: H2State, + pub streams: HashMap, + // context_hpack: HpackContext, + // settings: SettiongsH2, +} +pub struct Stream { + pub request_id: Ulid, + pub front: GenericHttpStream, + pub back: GenericHttpStream, +} + +pub enum Connection { + H1(ConnectionH1), + H2(ConnectionH2), +} +impl Connection { + fn readiness(&self) -> &Readiness { + match self { + Connection::H1(c) => &c.readiness, + Connection::H2(c) => &c.readiness, + } + } + fn readiness_mut(&mut self) -> &mut Readiness { + match self { + Connection::H1(c) => &mut c.readiness, + Connection::H2(c) => &mut c.readiness, + } + } + fn readable(&mut self, streams: &mut [Stream]) { + match self { + Connection::H1(c) => c.readable(streams), + Connection::H2(c) => c.readable(streams), + } + } + fn writable(&mut self, streams: &mut [Stream]) { + match self { + Connection::H1(c) => c.writable(streams), + Connection::H2(c) => c.writable(streams), + } + } +} + +pub struct Mux { + pub frontend_token: Token, + pub frontend: Connection, + pub backends: HashMap>, + pub streams: Vec, + pub listener: Rc>, + pub pool: Weak>, + pub public_address: SocketAddr, + pub peer_address: Option, + pub sticky_name: String, +} + +impl SessionState for Mux { + fn ready( + &mut self, + session: Rc>, + proxy: Rc>, + metrics: &mut SessionMetrics, + ) -> SessionResult { + let mut counter = 0; + let max_loop_iterations = 100000; + + if self.frontend.readiness().event.is_hup() { + return SessionResult::Close; + } + + let streams = &mut self.streams; + while counter < max_loop_iterations { + let mut dirty = false; + + if self.frontend.readiness().filter_interest().is_readable() { + self.frontend.readable(streams); + dirty = true; + } + + for (_, backend) in self.backends.iter_mut() { + if backend.readiness().filter_interest().is_writable() { + backend.writable(streams); + dirty = true; + } + + if backend.readiness().filter_interest().is_readable() { + backend.readable(streams); + dirty = true; + } + } + + if self.frontend.readiness().filter_interest().is_writable() { + self.frontend.writable(streams); + dirty = true; + } + + for backend in self.backends.values() { + if backend.readiness().filter_interest().is_hup() + || backend.readiness().filter_interest().is_error() + { + return SessionResult::Close; + } + } + + if !dirty { + break; + } + + counter += 1; + } + + if counter == max_loop_iterations { + incr!("http.infinite_loop.error"); + return SessionResult::Close; + } + + SessionResult::Continue + } + + fn update_readiness(&mut self, token: Token, events: sozu_command::ready::Ready) { + if token == self.frontend_token { + self.frontend.readiness_mut().event |= events; + } else if let Some(c) = self.backends.get_mut(&token) { + c.readiness_mut().event |= events; + } + } + + fn timeout(&mut self, token: Token, metrics: &mut SessionMetrics) -> StateResult { + println!("MuxState::timeout({token:?})"); + StateResult::CloseSession + } + + fn cancel_timeouts(&mut self) { + println!("MuxState::cancel_timeouts"); + } + + fn print_state(&self, context: &str) { + error!( + "\ +{} Session(Mux) +\tFrontend: +\t\ttoken: {:?}\treadiness: {:?}", + context, + self.frontend_token, + self.frontend.readiness() + ); + } + fn close(&mut self, _proxy: Rc>, _metrics: &mut SessionMetrics) { + let s = match &mut self.frontend { + Connection::H1(c) => &mut c.socket, + Connection::H2(c) => &mut c.socket, + }; + let mut b = [0; 1024]; + let (size, status) = s.socket_read(&mut b); + println!("{size} {status:?} {:?}", &b[..size]); + } +} + +impl Mux { + // pub fn new( + // frontend_token: Token, + // request_id: Ulid, + // listener: Rc>, + // pool: Weak>, + // public_address: SocketAddr, + // peer_address: Option, + // sticky_name: String, + // ) -> Self { + // Self { + // frontend_token, + // frontend: todo!(), + // backends: todo!(), + // streams: todo!(), + // } + // } + pub fn front_socket(&self) -> &TcpStream { + match &self.frontend { + Connection::H1(c) => &c.socket.stream, + Connection::H2(c) => &c.socket.stream, + } + } + + pub fn create_stream(&mut self, request_id: Ulid) -> Result { + let (front_buffer, back_buffer) = match self.pool.upgrade() { + Some(pool) => { + let mut pool = pool.borrow_mut(); + match (pool.checkout(), pool.checkout()) { + (Some(front_buffer), Some(back_buffer)) => (front_buffer, back_buffer), + _ => return Err(AcceptError::BufferCapacityReached), + } + } + None => return Err(AcceptError::BufferCapacityReached), + }; + self.streams.push(Stream { + request_id, + front: GenericHttpStream::new(kawa::Kind::Request, kawa::Buffer::new(front_buffer)), + back: GenericHttpStream::new(kawa::Kind::Request, kawa::Buffer::new(back_buffer)), + }); + Ok(self.streams.len() - 1) + } +} + +impl ConnectionH2 { + fn readable(&mut self, streams: &mut [Stream]) { + println!("======= MUX H2 READABLE"); + match (&self.state, &self.position) { + (H2State::ClientPreface, Position::Client) => { + error!("Waiting for ClientPreface to finish writing") + } + (H2State::ClientPreface, Position::Server) => { + let stream = &mut streams[0]; + let kawa = &mut stream.front; + let mut i = [0; 33]; + + // let (size, status) = self.socket.socket_read(kawa.storage.space()); + // println!("{:02x?}", &kawa.storage.buffer()[..size]); + // unreachable!(); + + let (size, status) = self.socket.socket_read(&mut i); + + println!("{size} {status:?} {i:02x?}"); + let i = match parser::preface(&i) { + Ok((i, _)) => i, + Err(_) => todo!(), + }; + let header = match parser::frame_header(&i) { + Ok(( + _, + header @ parser::FrameHeader { + payload_len: _, + frame_type: parser::FrameType::Settings, + flags: 0, + stream_id: 0, + }, + )) => header, + _ => todo!(), + }; + let (size, status) = self + .socket + .socket_read(&mut kawa.storage.space()[..header.payload_len as usize]); + kawa.storage.fill(size); + let i = kawa.storage.data(); + println!(" {size} {status:?} {i:02x?}"); + match parser::settings_frame(i, &header) { + Ok((_, settings)) => println!("{settings:?}"), + Err(_) => todo!(), + } + let kawa = &mut stream.back; + self.state = H2State::ServerPreface; + match serializer::gen_frame_header( + kawa.storage.space(), + &parser::FrameHeader { + payload_len: 6 * 2, + frame_type: parser::FrameType::Settings, + flags: 0, + stream_id: 0, + }, + ) { + Ok((_, size)) => kawa.storage.fill(size), + Err(e) => panic!("could not serialize HeaderFrame: {e:?}"), + }; + // kawa.storage + // .write(&[1, 3, 0, 0, 0, 100, 0, 4, 0, 1, 0, 0]) + // .unwrap(); + match serializer::gen_frame_header( + kawa.storage.space(), + &parser::FrameHeader { + payload_len: 0, + frame_type: parser::FrameType::Settings, + flags: 1, + stream_id: 0, + }, + ) { + Ok((_, size)) => kawa.storage.fill(size), + Err(e) => panic!("could not serialize HeaderFrame: {e:?}"), + }; + self.readiness.interest.insert(Ready::WRITABLE); + self.readiness.interest.remove(Ready::READABLE); + } + (H2State::ServerPreface, Position::Client) => todo!("Receive server Settings"), + (H2State::ServerPreface, Position::Server) => { + error!("waiting for ServerPreface to finish writing") + } + (H2State::Connected, Position::Server) => { + let mut header = [0; 9]; + let (size, status) = self.socket.socket_read(&mut header); + println!(" size: {size}, status: {status:?}"); + if size == 0 { + self.readiness.event.remove(Ready::READABLE); + return; + } + println!("{:?}", &header[..size]); + let len = match parser::frame_header(&header) { + Ok((_, h)) => { + println!("{h:?}"); + h.payload_len as usize + } + Err(_) => { + self.state = H2State::Error; + return; + } + }; + let kawa = &mut streams[0].front; + kawa.storage.clear(); + let (size, status) = self.socket.socket_read(&mut kawa.storage.space()[..len]); + kawa.storage.fill(size); + let i = kawa.storage.data(); + println!(" {size} {status:?} {i:?}"); + } + _ => unreachable!(), + } + } + fn writable(&mut self, streams: &mut [Stream]) { + println!("======= MUX H2 WRITABLE"); + match (&self.state, &self.position) { + (H2State::ClientPreface, Position::Client) => todo!("Send PRI + client Settings"), + (H2State::ClientPreface, Position::Server) => unreachable!(), + (H2State::ServerPreface, Position::Client) => unreachable!(), + (H2State::Connected, Position::Server) | (H2State::ServerPreface, Position::Server) => { + let stream = &mut streams[0]; + let kawa = &mut stream.back; + println!("{:?}", kawa.storage.data()); + let (size, status) = self.socket.socket_write(kawa.storage.data()); + println!(" size: {size}, status: {status:?}"); + let size = kawa.storage.available_data(); + kawa.storage.consume(size); + if kawa.storage.is_empty() { + self.readiness.interest.remove(Ready::WRITABLE); + self.readiness.interest.insert(Ready::READABLE); + self.state = H2State::Connected; + } + } + _ => unreachable!(), + } + // for global_stream_id in self.streams.values() { + // let stream = &mut streams[*global_stream_id]; + // let kawa = match self.position { + // Position::Client => &mut stream.back, + // Position::Server => &mut stream.front, + // }; + // kawa.prepare(&mut kawa::h2::BlockConverter); + // let (size, status) = self.socket.socket_write_vectored(&kawa.as_io_slice()); + // println!(" size: {size}, status: {status:?}"); + // kawa.consume(size); + // } + } +} + +impl ConnectionH1 { + fn readable(&mut self, streams: &mut [Stream]) { + println!("======= MUX H1 READABLE"); + let stream = &mut streams[self.stream]; + let kawa = match self.position { + Position::Client => &mut stream.front, + Position::Server => &mut stream.back, + }; + let (size, status) = self.socket.socket_read(kawa.storage.space()); + println!(" size: {size}, status: {status:?}"); + if size > 0 { + kawa.storage.fill(size); + } else { + self.readiness.event.remove(Ready::READABLE); + } + match status { + SocketResult::Continue => {} + SocketResult::Closed => todo!(), + SocketResult::Error => todo!(), + SocketResult::WouldBlock => self.readiness.event.remove(Ready::READABLE), + } + kawa::h1::parse(kawa, &mut kawa::h1::NoCallbacks); + kawa::debug_kawa(kawa); + if kawa.is_terminated() { + self.readiness.interest.remove(Ready::READABLE); + } + } + fn writable(&mut self, streams: &mut [Stream]) { + println!("======= MUX H1 WRITABLE"); + let stream = &mut streams[self.stream]; + let kawa = match self.position { + Position::Client => &mut stream.back, + Position::Server => &mut stream.front, + }; + kawa.prepare(&mut kawa::h1::BlockConverter); + let bufs = kawa.as_io_slice(); + if bufs.is_empty() { + self.readiness.interest.remove(Ready::WRITABLE); + return; + } + let (size, status) = self.socket.socket_write_vectored(&bufs); + println!(" size: {size}, status: {status:?}"); + if size > 0 { + kawa.consume(size); + // self.backend_readiness.interest.insert(Ready::READABLE); + } else { + self.readiness.event.remove(Ready::WRITABLE); + } + } +} diff --git a/lib/src/protocol/mux/parser.rs b/lib/src/protocol/mux/parser.rs new file mode 100644 index 000000000..1b53da321 --- /dev/null +++ b/lib/src/protocol/mux/parser.rs @@ -0,0 +1,468 @@ +use std::convert::From; + +use nom::{ + bytes::streaming::{tag, take}, + combinator::{complete, map, map_opt}, + error::{ErrorKind, ParseError}, + multi::many0, + number::streaming::{be_u16, be_u24, be_u32, be_u8}, + sequence::tuple, + Err, HexDisplay, IResult, Offset, +}; + +#[derive(Clone, Debug, PartialEq)] +pub struct FrameHeader { + pub payload_len: u32, + pub frame_type: FrameType, + pub flags: u8, + pub stream_id: u32, +} + +#[derive(Clone, Debug, PartialEq)] +pub enum FrameType { + Data, + Headers, + Priority, + RstStream, + Settings, + PushPromise, + Ping, + GoAway, + WindowUpdate, + Continuation, +} + +/* +const NO_ERROR: u32 = 0x0; +const PROTOCOL_ERROR: u32 = 0x1; +const INTERNAL_ERROR: u32 = 0x2; +const FLOW_CONTROL_ERROR: u32 = 0x3; +const SETTINGS_TIMEOUT: u32 = 0x4; +const STREAM_CLOSED: u32 = 0x5; +const FRAME_SIZE_ERROR: u32 = 0x6; +const REFUSED_STREAM: u32 = 0x7; +const CANCEL: u32 = 0x8; +const COMPRESSION_ERROR: u32 = 0x9; +const CONNECT_ERROR: u32 = 0xa; +const ENHANCE_YOUR_CALM: u32 = 0xb; +const INADEQUATE_SECURITY: u32 = 0xc; +const HTTP_1_1_REQUIRED: u32 = 0xd; +*/ + +#[derive(Clone, Debug, PartialEq)] +pub struct Error<'a> { + pub input: &'a [u8], + pub error: InnerError, +} + +#[derive(Clone, Debug, PartialEq)] +pub enum InnerError { + Nom(ErrorKind), + NoError, + ProtocolError, + InternalError, + FlowControlError, + SettingsTimeout, + StreamClosed, + FrameSizeError, + RefusedStream, + Cancel, + CompressionError, + ConnectError, + EnhanceYourCalm, + InadequateSecurity, + HTTP11Required, +} + +impl<'a> Error<'a> { + pub fn new(input: &'a [u8], error: InnerError) -> Error<'a> { + Error { input, error } + } +} + +impl<'a> ParseError<&'a [u8]> for Error<'a> { + fn from_error_kind(input: &'a [u8], kind: ErrorKind) -> Self { + Error { + input, + error: InnerError::Nom(kind), + } + } + + fn append(input: &'a [u8], kind: ErrorKind, other: Self) -> Self { + Error { + input, + error: InnerError::Nom(kind), + } + } +} + +impl<'a> From<(&'a [u8], ErrorKind)> for Error<'a> { + fn from((input, kind): (&'a [u8], ErrorKind)) -> Self { + Error { + input, + error: InnerError::Nom(kind), + } + } +} + +pub fn preface(i: &[u8]) -> IResult<&[u8], &[u8]> { + tag(b"PRI * HTTP/2.0\r\n\r\nSM\r\n\r\n")(i) +} + +// https://httpwg.org/specs/rfc7540.html#rfc.section.4.1 +/*named!(pub frame_header, + do_parse!( + payload_len: dbg_dmp!(be_u24) >> + frame_type: map_opt!(be_u8, convert_frame_type) >> + flags: dbg_dmp!(be_u8) >> + stream_id: dbg_dmp!(verify!(be_u32, |id| { + match frame_type { + + } + }) >> + (FrameHeader { payload_len, frame_type, flags, stream_id }) + ) +); + */ + +pub fn frame_header(input: &[u8]) -> IResult<&[u8], FrameHeader, Error> { + let (i1, payload_len) = be_u24(input)?; + let (i2, frame_type) = map_opt(be_u8, convert_frame_type)(i1)?; + let (i3, flags) = be_u8(i2)?; + let (i4, stream_id) = be_u32(i3)?; + + Ok(( + i4, + FrameHeader { + payload_len, + frame_type, + flags, + stream_id, + }, + )) +} + +fn convert_frame_type(t: u8) -> Option { + info!("got frame type: {}", t); + match t { + 0 => Some(FrameType::Data), + 1 => Some(FrameType::Headers), + 2 => Some(FrameType::Priority), + 3 => Some(FrameType::RstStream), + 4 => Some(FrameType::Settings), + 5 => Some(FrameType::PushPromise), + 6 => Some(FrameType::Ping), + 7 => Some(FrameType::GoAway), + 8 => Some(FrameType::WindowUpdate), + 9 => Some(FrameType::Continuation), + _ => None, + } +} + +#[derive(Clone, Debug, PartialEq)] +pub enum Frame<'a> { + Data(Data<'a>), + Headers(Headers<'a>), + Priority, + RstStream(RstStream), + Settings(Settings), + PushPromise, + Ping(Ping), + GoAway, + WindowUpdate(WindowUpdate), + Continuation, +} + +impl<'a> Frame<'a> { + pub fn is_stream_specific(&self) -> bool { + match self { + Frame::Data(_) + | Frame::Headers(_) + | Frame::Priority + | Frame::RstStream(_) + | Frame::PushPromise + | Frame::Continuation => true, + Frame::Settings(_) | Frame::Ping(_) | Frame::GoAway => false, + Frame::WindowUpdate(w) => w.stream_id != 0, + } + } + + pub fn stream_id(&self) -> u32 { + match self { + Frame::Data(d) => d.stream_id, + Frame::Headers(h) => h.stream_id, + Frame::Priority => unimplemented!(), + Frame::RstStream(r) => r.stream_id, + Frame::PushPromise => unimplemented!(), + Frame::Continuation => unimplemented!(), + Frame::Settings(_) | Frame::Ping(_) | Frame::GoAway => 0, + Frame::WindowUpdate(w) => w.stream_id, + } + } +} + +pub fn frame<'a>(input: &'a [u8], max_frame_size: u32) -> IResult<&'a [u8], Frame<'a>, Error<'a>> { + let (i, header) = frame_header(input)?; + + info!("got frame header: {:?}", header); + + if header.payload_len > max_frame_size { + return Err(Err::Failure(Error::new(input, InnerError::FrameSizeError))); + } + + let valid_stream_id = match header.frame_type { + FrameType::Data + | FrameType::Headers + | FrameType::Priority + | FrameType::RstStream + | FrameType::PushPromise + | FrameType::Continuation => header.stream_id != 0, + FrameType::Settings | FrameType::Ping | FrameType::GoAway => header.stream_id == 0, + FrameType::WindowUpdate => true, + }; + + if !valid_stream_id { + return Err(Err::Failure(Error::new(input, InnerError::ProtocolError))); + } + + let f = match header.frame_type { + FrameType::Data => data_frame(i, &header)?, + FrameType::Headers => headers_frame(i, &header)?, + FrameType::Priority => { + if header.payload_len != 5 { + return Err(Err::Failure(Error::new(input, InnerError::FrameSizeError))); + } + unimplemented!(); + } + FrameType::RstStream => { + if header.payload_len != 4 { + return Err(Err::Failure(Error::new(input, InnerError::FrameSizeError))); + } + rst_stream_frame(i, &header)? + } + FrameType::PushPromise => { + unimplemented!(); + } + FrameType::Continuation => { + unimplemented!(); + } + FrameType::Settings => { + if header.payload_len % 6 != 0 { + return Err(Err::Failure(Error::new(input, InnerError::FrameSizeError))); + } + settings_frame(i, &header)? + } + FrameType::Ping => { + if header.payload_len != 8 { + return Err(Err::Failure(Error::new(input, InnerError::FrameSizeError))); + } + ping_frame(i, &header)? + } + FrameType::GoAway => { + unimplemented!(); + } + FrameType::WindowUpdate => { + if header.payload_len != 4 { + return Err(Err::Failure(Error::new(input, InnerError::FrameSizeError))); + } + window_update_frame(i, &header)? + } + }; + + Ok(f) +} + +#[derive(Clone, Debug, PartialEq)] +pub struct Data<'a> { + pub stream_id: u32, + pub payload: &'a [u8], + pub end_stream: bool, +} + +pub fn data_frame<'a, 'b>( + input: &'a [u8], + header: &'b FrameHeader, +) -> IResult<&'a [u8], Frame<'a>, Error<'a>> { + let (remaining, i) = take(header.payload_len)(input)?; + + let (i1, pad_length) = if header.flags & 0x8 != 0 { + let (i, pad_length) = be_u8(i)?; + (i, Some(pad_length)) + } else { + (i, None) + }; + + if pad_length.is_some() && i1.len() <= pad_length.unwrap() as usize { + return Err(Err::Failure(Error::new(input, InnerError::ProtocolError))); + } + + let (_, payload) = take(i1.len() - pad_length.unwrap_or(0) as usize)(i1)?; + + Ok(( + remaining, + Frame::Data(Data { + stream_id: header.stream_id, + payload, + end_stream: header.flags & 0x1 != 0, + }), + )) +} + +#[derive(Clone, Debug, PartialEq)] +pub struct Headers<'a> { + pub stream_id: u32, + pub stream_dependency: Option, + pub weight: Option, + pub header_block_fragment: &'a [u8], + pub end_stream: bool, + pub end_headers: bool, + pub priority: bool, +} + +#[derive(Clone, Debug, PartialEq)] +pub struct StreamDependency { + pub exclusive: bool, + pub stream_id: u32, +} + +pub fn headers_frame<'a, 'b>( + input: &'a [u8], + header: &'b FrameHeader, +) -> IResult<&'a [u8], Frame<'a>, Error<'a>> { + let (remaining, i) = take(header.payload_len)(input)?; + + let (i1, pad_length) = if header.flags & 0x8 != 0 { + let (i, pad_length) = be_u8(i)?; + (i, Some(pad_length)) + } else { + (i, None) + }; + + let (i2, stream_dependency) = if header.flags & 0x20 != 0 { + let (i, stream) = map(be_u32, |i| StreamDependency { + exclusive: i & 0x8000 != 0, + stream_id: i & 0x7FFF, + })(i1)?; + (i, Some(stream)) + } else { + (i1, None) + }; + + let (i3, weight) = if header.flags & 0x20 != 0 { + let (i, weight) = be_u8(i2)?; + (i, Some(weight)) + } else { + (i2, None) + }; + + if pad_length.is_some() && i3.len() <= pad_length.unwrap() as usize { + return Err(Err::Failure(Error::new(input, InnerError::ProtocolError))); + } + + let (_, header_block_fragment) = take(i3.len() - pad_length.unwrap_or(0) as usize)(i3)?; + + Ok(( + remaining, + Frame::Headers(Headers { + stream_id: header.stream_id, + stream_dependency, + weight, + header_block_fragment, + end_stream: header.flags & 0x1 != 0, + end_headers: header.flags & 0x4 != 0, + priority: header.flags & 0x20 != 0, + }), + )) +} + +#[derive(Clone, Debug, PartialEq)] +pub struct RstStream { + pub stream_id: u32, + pub error_code: u32, +} + +pub fn rst_stream_frame<'a, 'b>( + input: &'a [u8], + header: &'b FrameHeader, +) -> IResult<&'a [u8], Frame<'a>, Error<'a>> { + let (i, error_code) = be_u32(input)?; + Ok(( + i, + Frame::RstStream(RstStream { + stream_id: header.stream_id, + error_code, + }), + )) +} + +#[derive(Clone, Debug, PartialEq)] +pub struct Settings { + pub settings: Vec, +} + +#[derive(Clone, Debug, PartialEq)] +pub struct Setting { + pub identifier: u16, + pub value: u32, +} + +pub fn settings_frame<'a, 'b>( + input: &'a [u8], + header: &'b FrameHeader, +) -> IResult<&'a [u8], Frame<'a>, Error<'a>> { + let (i, data) = take(header.payload_len)(input)?; + + let (_, settings) = many0(map( + complete(tuple((be_u16, be_u32))), + |(identifier, value)| Setting { identifier, value }, + ))(data)?; + + Ok((i, Frame::Settings(Settings { settings }))) +} + +#[derive(Clone, Debug, PartialEq)] +pub struct Ping { + pub payload: [u8; 8], +} + +pub fn ping_frame<'a, 'b>( + input: &'a [u8], + header: &'b FrameHeader, +) -> IResult<&'a [u8], Frame<'a>, Error<'a>> { + let (i, data) = take(8usize)(input)?; + + let mut p = Ping { payload: [0; 8] }; + + for i in 0..8 { + p.payload[i] = data[i]; + } + + Ok((i, Frame::Ping(p))) +} + +#[derive(Clone, Debug, PartialEq)] +pub struct WindowUpdate { + pub stream_id: u32, + pub increment: u32, +} + +pub fn window_update_frame<'a, 'b>( + input: &'a [u8], + header: &'b FrameHeader, +) -> IResult<&'a [u8], Frame<'a>, Error<'a>> { + let (i, increment) = be_u32(input)?; + let increment = increment & 0x7FFF; + + //FIXME: if stream id is 0, trat it as connection error? + if increment == 0 { + return Err(Err::Failure(Error::new(input, InnerError::ProtocolError))); + } + + Ok(( + i, + Frame::WindowUpdate(WindowUpdate { + stream_id: header.stream_id, + increment, + }), + )) +} diff --git a/lib/src/protocol/mux/serializer.rs b/lib/src/protocol/mux/serializer.rs new file mode 100644 index 000000000..ddc6c437c --- /dev/null +++ b/lib/src/protocol/mux/serializer.rs @@ -0,0 +1,37 @@ +use cookie_factory::{ + bytes::{be_u24, be_u32, be_u8}, + gen, + sequence::tuple, + GenError, +}; + +use super::parser::{FrameHeader, FrameType}; + +pub fn gen_frame_header<'a, 'b>( + buf: &'a mut [u8], + frame: &'b FrameHeader, +) -> Result<(&'a mut [u8], usize), GenError> { + let serializer = tuple(( + be_u24(frame.payload_len), + be_u8(serialize_frame_type(&frame.frame_type)), + be_u8(frame.flags), + be_u32(frame.stream_id), + )); + + gen(serializer, buf).map(|(buf, size)| (buf, size as usize)) +} + +pub fn serialize_frame_type(f: &FrameType) -> u8 { + match *f { + FrameType::Data => 0, + FrameType::Headers => 1, + FrameType::Priority => 2, + FrameType::RstStream => 3, + FrameType::Settings => 4, + FrameType::PushPromise => 5, + FrameType::Ping => 6, + FrameType::GoAway => 7, + FrameType::WindowUpdate => 8, + FrameType::Continuation => 9, + } +}