Skip to content

Commit

Permalink
make it clear in HttpTransparentRequest::retain_upgrade that somethin…
Browse files Browse the repository at this point in the history
…g happend (#349)

as requested by zh-jq in #341
  • Loading branch information
GlenDC authored Oct 10, 2024
1 parent 8ec28d7 commit 90c721e
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 4 deletions.
2 changes: 1 addition & 1 deletion g3proxy/src/inspect/http/v1/upgrade/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -163,7 +163,7 @@ where
true
});

if upgrade_token_count == 0 {
if upgrade_token_count == Some(0) {
let rsp = HttpProxyClientResponse::forbidden(self.req.version);
self.should_close = true;
if rsp.reply_err_to_request(clt_w).await.is_ok() {
Expand Down
10 changes: 7 additions & 3 deletions lib/g3-http/src/server/transparent.rs
Original file line number Diff line number Diff line change
Expand Up @@ -289,11 +289,12 @@ impl HttpTransparentRequest {
Ok(())
}

pub fn retain_upgrade<F>(&mut self, retain: F) -> usize
pub fn retain_upgrade<F>(&mut self, retain: F) -> Option<usize>
where
F: Fn(HttpUpgradeToken) -> bool,
{
let mut new_upgrade_headers = Vec::new();
let mut headers_found = false;
for header in self.hop_by_hop_headers.get_all(header::UPGRADE) {
let value = header.to_str();
for s in value.split(',') {
Expand All @@ -305,6 +306,7 @@ impl HttpTransparentRequest {
let Ok(protocol) = HttpUpgradeToken::from_str(s) else {
continue;
};
headers_found = true;
if retain(protocol) {
let mut new_value =
unsafe { HttpHeaderValue::from_string_unchecked(s.to_string()) };
Expand All @@ -321,7 +323,7 @@ impl HttpTransparentRequest {
for value in new_upgrade_headers {
self.hop_by_hop_headers.append(header::UPGRADE, value);
}
retain_count
headers_found.then_some(retain_count)
}

fn insert_hop_by_hop_header(
Expand Down Expand Up @@ -593,7 +595,9 @@ mod tests {
let (mut request, _) = HttpTransparentRequest::parse(&mut buf_stream, 4096, false)
.await
.unwrap();
let left_tokens = request.retain_upgrade(|p| matches!(p, HttpUpgradeToken::Http(_)));
let left_tokens = request
.retain_upgrade(|p| matches!(p, HttpUpgradeToken::Http(_)))
.unwrap();
assert_eq!(left_tokens, 1);
let token = request.hop_by_hop_headers.get(header::UPGRADE).unwrap();
assert_eq!(token.to_str(), "HTTP/2.0");
Expand Down

0 comments on commit 90c721e

Please sign in to comment.