From e08c14f5b26dfb9190caedf793ce59acf9634773 Mon Sep 17 00:00:00 2001 From: Max Inden Date: Thu, 31 Oct 2024 18:52:28 +0100 Subject: [PATCH] bench(bin/client): don't allocate upload payload upfront (#2200) * bench(bin/client): don't allocate upload payload upfront When POSTing a large request to a server, don't allocate the entire request upfront, but instead, as is done in `neqo-bin/src/server/mod.rs`, iterate over a static buffer. Reuses the same logic from `neqo-bin/src/server/mod.rs`, i.e. `SendData`. See previous similar change on server side https://github.com/mozilla/neqo/pull/2008. * Inline done() --- neqo-bin/src/client/http3.rs | 30 +++-------- neqo-bin/src/lib.rs | 1 + neqo-bin/src/send_data.rs | 78 +++++++++++++++++++++++++++++ neqo-bin/src/server/http09.rs | 14 +++--- neqo-bin/src/server/http3.rs | 21 ++++---- neqo-bin/src/server/mod.rs | 93 +---------------------------------- 6 files changed, 107 insertions(+), 130 deletions(-) create mode 100644 neqo-bin/src/send_data.rs diff --git a/neqo-bin/src/client/http3.rs b/neqo-bin/src/client/http3.rs index 6ce41bef7c..b847f9a5f4 100644 --- a/neqo-bin/src/client/http3.rs +++ b/neqo-bin/src/client/http3.rs @@ -28,7 +28,7 @@ use neqo_transport::{ use url::Url; use super::{get_output_file, qlog_new, Args, CloseState, Res}; -use crate::STREAM_IO_BUFFER_SIZE; +use crate::{send_data::SendData, STREAM_IO_BUFFER_SIZE}; pub struct Handler<'a> { #[allow(clippy::struct_field_names)] @@ -312,9 +312,7 @@ impl StreamHandler for DownloadStreamHandler { } struct UploadStreamHandler { - data: Vec, - offset: usize, - chunk_size: usize, + data: SendData, start: Instant, } @@ -344,21 +342,11 @@ impl StreamHandler for UploadStreamHandler { } fn process_data_writable(&mut self, client: &mut Http3Client, stream_id: StreamId) { - while self.offset < self.data.len() { - let end = self.offset + self.chunk_size.min(self.data.len() - self.offset); - let chunk = &self.data[self.offset..end]; - match client.send_data(stream_id, chunk) { - Ok(amount) => { - if amount == 0 { - break; - } - self.offset += amount; - if self.offset == self.data.len() { - client.stream_close_send(stream_id).unwrap(); - } - } - Err(_) => break, - }; + let done = self + .data + .send(|chunk| client.send_data(stream_id, chunk).unwrap()); + if done { + client.stream_close_send(stream_id).unwrap(); } } } @@ -416,9 +404,7 @@ impl UrlHandler<'_> { Box::new(DownloadStreamHandler { out_file }) } "POST" => Box::new(UploadStreamHandler { - data: vec![42; self.args.upload_size], - offset: 0, - chunk_size: STREAM_IO_BUFFER_SIZE, + data: SendData::zeroes(self.args.upload_size), start: Instant::now(), }), _ => unimplemented!(), diff --git a/neqo-bin/src/lib.rs b/neqo-bin/src/lib.rs index a439183a74..f151f2642e 100644 --- a/neqo-bin/src/lib.rs +++ b/neqo-bin/src/lib.rs @@ -21,6 +21,7 @@ use neqo_transport::{ }; pub mod client; +mod send_data; pub mod server; pub mod udp; diff --git a/neqo-bin/src/send_data.rs b/neqo-bin/src/send_data.rs new file mode 100644 index 0000000000..5634d4c5c8 --- /dev/null +++ b/neqo-bin/src/send_data.rs @@ -0,0 +1,78 @@ +// Licensed under the Apache License, Version 2.0 or the MIT license +// , at your +// option. This file may not be copied, modified, or distributed +// except according to those terms. + +use std::{borrow::Cow, cmp::min}; + +use crate::STREAM_IO_BUFFER_SIZE; + +#[derive(Debug)] +pub struct SendData { + data: Cow<'static, [u8]>, + offset: usize, + remaining: usize, + total: usize, +} + +impl From<&[u8]> for SendData { + fn from(data: &[u8]) -> Self { + Self::from(data.to_vec()) + } +} + +impl From> for SendData { + fn from(data: Vec) -> Self { + let remaining = data.len(); + Self { + total: data.len(), + data: Cow::Owned(data), + offset: 0, + remaining, + } + } +} + +impl From<&str> for SendData { + fn from(data: &str) -> Self { + Self::from(data.as_bytes()) + } +} + +impl SendData { + pub const fn zeroes(total: usize) -> Self { + const MESSAGE: &[u8] = &[0; STREAM_IO_BUFFER_SIZE]; + Self { + data: Cow::Borrowed(MESSAGE), + offset: 0, + remaining: total, + total, + } + } + + fn slice(&self) -> &[u8] { + let end = min(self.data.len(), self.offset + self.remaining); + &self.data[self.offset..end] + } + + pub fn send(&mut self, mut f: impl FnMut(&[u8]) -> usize) -> bool { + while self.remaining > 0 { + match f(self.slice()) { + 0 => { + return false; + } + sent => { + self.remaining -= sent; + self.offset = (self.offset + sent) % self.data.len(); + } + } + } + + self.remaining == 0 + } + + pub const fn len(&self) -> usize { + self.total + } +} diff --git a/neqo-bin/src/server/http09.rs b/neqo-bin/src/server/http09.rs index 1815140b01..1887e3ac6f 100644 --- a/neqo-bin/src/server/http09.rs +++ b/neqo-bin/src/server/http09.rs @@ -15,13 +15,13 @@ use neqo_transport::{ }; use regex::Regex; -use super::{qns_read_response, Args, ResponseData}; -use crate::STREAM_IO_BUFFER_SIZE; +use super::{qns_read_response, Args}; +use crate::{send_data::SendData, STREAM_IO_BUFFER_SIZE}; #[derive(Default)] struct HttpStreamState { writable: bool, - data_to_send: Option, + data_to_send: Option, } pub struct HttpServer { @@ -127,7 +127,7 @@ impl HttpServer { return; }; - let resp: ResponseData = { + let resp: SendData = { let path = path.as_str(); qdebug!("Path = '{path}'"); if self.is_qns_test { @@ -140,7 +140,7 @@ impl HttpServer { } } else { let count = path.parse().unwrap(); - ResponseData::zeroes(count) + SendData::zeroes(count) } }; @@ -173,8 +173,8 @@ impl HttpServer { stream_state.writable = true; if let Some(resp) = &mut stream_state.data_to_send { - resp.send_h09(stream_id, conn); - if resp.done() { + let done = resp.send(|chunk| conn.borrow_mut().stream_send(stream_id, chunk).unwrap()); + if done { conn.borrow_mut().stream_close_send(stream_id).unwrap(); self.write_state.remove(&stream_id); } else { diff --git a/neqo-bin/src/server/http3.rs b/neqo-bin/src/server/http3.rs index 3506387a62..dfef3f1be4 100644 --- a/neqo-bin/src/server/http3.rs +++ b/neqo-bin/src/server/http3.rs @@ -19,12 +19,13 @@ use neqo_http3::{ }; use neqo_transport::{server::ValidateAddress, ConnectionIdGenerator}; -use super::{qns_read_response, Args, ResponseData}; +use super::{qns_read_response, Args}; +use crate::send_data::SendData; pub struct HttpServer { server: Http3Server, /// Progress writing to each stream. - remaining_data: HashMap, + remaining_data: HashMap, posts: HashMap, is_qns_test: bool, } @@ -110,7 +111,7 @@ impl super::HttpServer for HttpServer { let mut response = if self.is_qns_test { match qns_read_response(path.value()) { - Ok(data) => ResponseData::from(data), + Ok(data) => SendData::from(data), Err(e) => { qerror!("Failed to read {}: {e}", path.value()); stream @@ -123,19 +124,19 @@ impl super::HttpServer for HttpServer { } else if let Ok(count) = path.value().trim_matches(|p| p == '/').parse::() { - ResponseData::zeroes(count) + SendData::zeroes(count) } else { - ResponseData::from(path.value()) + SendData::from(path.value()) }; stream .send_headers(&[ Header::new(":status", "200"), - Header::new("content-length", response.remaining.to_string()), + Header::new("content-length", response.len().to_string()), ]) .unwrap(); - response.send_h3(&stream); - if response.done() { + let done = response.send(|chunk| stream.send_data(chunk).unwrap()); + if done { stream.stream_close_send().unwrap(); } else { self.remaining_data.insert(stream.stream_id(), response); @@ -144,8 +145,8 @@ impl super::HttpServer for HttpServer { Http3ServerEvent::DataWritable { stream } => { if self.posts.get_mut(&stream).is_none() { if let Some(remaining) = self.remaining_data.get_mut(&stream.stream_id()) { - remaining.send_h3(&stream); - if remaining.done() { + let done = remaining.send(|chunk| stream.send_data(chunk).unwrap()); + if done { self.remaining_data.remove(&stream.stream_id()); stream.stream_close_send().unwrap(); } diff --git a/neqo-bin/src/server/mod.rs b/neqo-bin/src/server/mod.rs index 27e7ab3d95..8927890e8e 100644 --- a/neqo-bin/src/server/mod.rs +++ b/neqo-bin/src/server/mod.rs @@ -7,9 +7,7 @@ #![allow(clippy::future_not_send)] use std::{ - borrow::Cow, cell::RefCell, - cmp::min, fmt::{self, Display}, fs, io, net::{SocketAddr, ToSocketAddrs}, @@ -30,11 +28,10 @@ use neqo_crypto::{ constants::{TLS_AES_128_GCM_SHA256, TLS_AES_256_GCM_SHA384, TLS_CHACHA20_POLY1305_SHA256}, init_db, AntiReplay, Cipher, }; -use neqo_http3::{Http3OrWebTransportStream, StreamId}; -use neqo_transport::{server::ConnectionRef, Output, RandomConnectionIdGenerator, Version}; +use neqo_transport::{Output, RandomConnectionIdGenerator, Version}; use tokio::time::Sleep; -use crate::{SharedArgs, STREAM_IO_BUFFER_SIZE}; +use crate::SharedArgs; const ANTI_REPLAY_WINDOW: Duration = Duration::from_secs(10); @@ -409,89 +406,3 @@ pub async fn server(mut args: Args) -> Res<()> { .run() .await } - -#[derive(Debug)] -struct ResponseData { - data: Cow<'static, [u8]>, - offset: usize, - remaining: usize, -} - -impl From<&[u8]> for ResponseData { - fn from(data: &[u8]) -> Self { - Self::from(data.to_vec()) - } -} - -impl From> for ResponseData { - fn from(data: Vec) -> Self { - let remaining = data.len(); - Self { - data: Cow::Owned(data), - offset: 0, - remaining, - } - } -} - -impl From<&str> for ResponseData { - fn from(data: &str) -> Self { - Self::from(data.as_bytes()) - } -} - -impl ResponseData { - const fn zeroes(total: usize) -> Self { - const MESSAGE: &[u8] = &[0; STREAM_IO_BUFFER_SIZE]; - Self { - data: Cow::Borrowed(MESSAGE), - offset: 0, - remaining: total, - } - } - - fn slice(&self) -> &[u8] { - let end = min(self.data.len(), self.offset + self.remaining); - &self.data[self.offset..end] - } - - fn send_h3(&mut self, stream: &Http3OrWebTransportStream) { - while self.remaining > 0 { - match stream.send_data(self.slice()) { - Ok(0) => { - return; - } - Ok(sent) => { - self.remaining -= sent; - self.offset = (self.offset + sent) % self.data.len(); - } - Err(e) => { - qwarn!("Error writing to stream {}: {:?}", stream, e); - return; - } - } - } - } - - fn send_h09(&mut self, stream_id: StreamId, conn: &ConnectionRef) { - while self.remaining > 0 { - match conn - .borrow_mut() - .stream_send(stream_id, self.slice()) - .unwrap() - { - 0 => { - return; - } - sent => { - self.remaining -= sent; - self.offset = (self.offset + sent) % self.data.len(); - } - } - } - } - - const fn done(&self) -> bool { - self.remaining == 0 - } -}