Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Remove format! allocations during parsing #47

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 4 additions & 5 deletions src/buffer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,7 @@ impl<'r> StreamBuffer<'r> {

pub fn read_field_data(
&mut self,
boundary: &str,
boundary: &[u8],
field_name: Option<&str>,
) -> crate::Result<Option<(bool, Bytes)>> {
log::trace!("finding next field: {:?}", field_name);
Expand All @@ -105,10 +105,9 @@ impl<'r> StreamBuffer<'r> {
return Ok(None);
}

let boundary_deriv = format!("{}{}{}", constants::CRLF, constants::BOUNDARY_EXT, boundary);
let b_len = boundary_deriv.len();
let b_len = boundary.len();

match memchr::memmem::find(&self.buf, boundary_deriv.as_bytes()) {
match memchr::memmem::find(&self.buf, boundary) {
Some(idx) => {
log::trace!("new field found at {}", idx);
let bytes = self.buf.split_to(idx).freeze();
Expand Down Expand Up @@ -139,7 +138,7 @@ impl<'r> StreamBuffer<'r> {
Some(rel_idx) => {
let idx = rel_idx + rem_boundary_part_idx;

match memchr::memmem::find(boundary_deriv.as_bytes(), &self.buf[idx..]) {
match memchr::memmem::find(boundary, &self.buf[idx..]) {
Some(_) => {
let bytes = self.buf.split_to(idx).freeze();

Expand Down
2 changes: 1 addition & 1 deletion src/field.rs
Original file line number Diff line number Diff line change
Expand Up @@ -344,7 +344,7 @@ impl Stream for Field<'_> {

match state
.buffer
.read_field_data(&state.boundary, state.curr_field_name.as_deref())
.read_field_data(state.boundary.as_bytes_with_crlf(), state.curr_field_name.as_deref())
{
Ok(Some((done, bytes))) => {
state.curr_field_size_counter += bytes.len() as u64;
Expand Down
38 changes: 29 additions & 9 deletions src/multipart.rs
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@ pub struct Multipart<'r> {
#[derive(Debug)]
pub(crate) struct MultipartState<'r> {
pub(crate) buffer: StreamBuffer<'r>,
pub(crate) boundary: String,
pub(crate) boundary: Boundary,
pub(crate) stage: StreamingStage,
pub(crate) next_field_idx: usize,
pub(crate) curr_field_name: Option<String>,
Expand All @@ -88,6 +88,29 @@ pub(crate) struct MultipartState<'r> {
pub(crate) constraints: Constraints,
}

#[derive(Debug)]
pub(crate) struct Boundary(String);

impl Boundary {
fn new(boundary: String) -> Self {
let cap = constants::CRLF.len() + constants::BOUNDARY_EXT.len() + boundary.len();
let mut buf = String::with_capacity(cap);
buf.push_str(constants::CRLF);
buf.push_str(constants::BOUNDARY_EXT);
buf.push_str(&boundary);
Self(buf)
Comment on lines +96 to +101
Copy link
Member

@SergioBenitez SergioBenitez Jan 31, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Though this might be better than using format!() directly, the hope would be to remove all unnecessary allocations along the parsing path. Instead of using format!() to create a string to compare, perform the comparison directly. That is, instead of:

"ab" == format!("{}{}", "a", "b");

Do something closer to:

"ab"[..1] == "a" && "ab"[1..] == "b"

But safely, of course, considering all bounds and possible string lengths.

We could imagine a generalized function:

fn string_is_parts(string: &str, parts: &[&str]) -> bool;

string_is_parts("ab", &["a", "b"]);
string_is_parts(maybe_boundary, &[constants::CRLF, constants::BOUNDARY_EXT, boundary]);

}

pub(crate) fn as_bytes(&self) -> &[u8] {
// "--" + boundary
self.0[2..].as_bytes()
}

pub(crate) fn as_bytes_with_crlf(&self) -> &[u8] {
self.0.as_bytes()
}
}

#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub(crate) enum StreamingStage {
FindingFirstBoundary,
Expand Down Expand Up @@ -128,7 +151,7 @@ impl<'r> Multipart<'r> {
Multipart {
state: Arc::new(Mutex::new(MultipartState {
buffer: StreamBuffer::new(stream, constraints.size_limit.whole_stream),
boundary: boundary.into(),
boundary: Boundary::new(boundary.into()),
stage: StreamingStage::FindingFirstBoundary,
next_field_idx: 0,
curr_field_name: None,
Expand Down Expand Up @@ -251,9 +274,7 @@ impl<'r> Multipart<'r> {
}

if state.stage == StreamingStage::FindingFirstBoundary {
let boundary = &state.boundary;
let boundary_deriv = format!("{}{}", constants::BOUNDARY_EXT, boundary);
match state.buffer.read_to(boundary_deriv.as_bytes()) {
match state.buffer.read_to(state.boundary.as_bytes()) {
Some(_) => state.stage = StreamingStage::ReadingBoundary,
None => {
if let Err(err) = state.buffer.poll_stream(cx) {
Expand All @@ -270,7 +291,7 @@ impl<'r> Multipart<'r> {
if state.stage == StreamingStage::ReadingFieldData {
match state
.buffer
.read_field_data(state.boundary.as_str(), state.curr_field_name.as_deref())?
.read_field_data(state.boundary.as_bytes_with_crlf(), state.curr_field_name.as_deref())?
{
Some((done, bytes)) => {
state.curr_field_size_counter += bytes.len() as u64;
Expand All @@ -295,8 +316,7 @@ impl<'r> Multipart<'r> {
}

if state.stage == StreamingStage::ReadingBoundary {
let boundary = &state.boundary;
let boundary_deriv_len = constants::BOUNDARY_EXT.len() + boundary.len();
let boundary_deriv_len = state.boundary.as_bytes().len();

let boundary_bytes = match state.buffer.read_exact(boundary_deriv_len) {
Some(bytes) => bytes,
Expand All @@ -309,7 +329,7 @@ impl<'r> Multipart<'r> {
}
};

if &boundary_bytes[..] == format!("{}{}", constants::BOUNDARY_EXT, boundary).as_bytes() {
if &boundary_bytes[..] == state.boundary.as_bytes() {
state.stage = StreamingStage::DeterminingBoundaryType;
} else {
return Poll::Ready(Err(crate::Error::IncompleteStream));
Expand Down