Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Use the mime crate to parse 'content-type' header into Media Types #3649

Merged
merged 8 commits into from
Feb 10, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,7 @@ janus_interop_binaries = { version = "0.8.0-prerelease-1", path = "interop_binar
janus_messages = { version = "0.8.0-prerelease-1", path = "messages" }
k8s-openapi = { version = "0.22.0", features = ["v1_26"] } # keep this version in sync with what is referenced by the indirect dependency via `kube`
kube = { version = "0.94.2", default-features = false, features = ["client", "rustls-tls"] }
mime = "0.3.17"
mockito = "1.6.1"
num_enum = "0.7.3"
ohttp = { version = "0.5.1", default-features = false }
Expand Down
1 change: 1 addition & 0 deletions aggregator/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,7 @@ janus_core.workspace = true
janus_messages.workspace = true
k8s-openapi.workspace = true
kube.workspace = true
mime.workspace = true
moka = { version = "0.12.10", features = ["future"] }
opentelemetry.workspace = true
opentelemetry-otlp = { workspace = true, features = ["metrics"], optional = true }
Expand Down
30 changes: 20 additions & 10 deletions aggregator/src/aggregator/http_handlers.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ use janus_messages::{
AggregationJobInitializeReq, AggregationJobResp, CollectionJobId, CollectionJobReq,
CollectionJobResp, HpkeConfigList, Report, TaskId,
};
use mime::{FromStrError, Mime};
use opentelemetry::{
metrics::{Counter, Meter},
KeyValue,
Expand Down Expand Up @@ -712,17 +713,26 @@ async fn aggregate_shares<C: Clock>(
/// Check the request's Content-Type header, and return an error if it is missing or not equal to
/// the expected value.
fn validate_content_type(conn: &Conn, expected_media_type: &'static str) -> Result<(), Error> {
if let Some(content_type) = conn.request_headers().get(KnownHeaderName::ContentType) {
if content_type != expected_media_type {
Err(Error::BadRequest(format!(
"wrong Content-Type header: {content_type}"
)))
} else {
Ok(())
}
} else {
Err(Error::BadRequest("no Content-Type header".to_owned()))
let content_type = conn
.request_headers()
.get(KnownHeaderName::ContentType)
.ok_or_else(|| Error::BadRequest("no Content-Type header".to_owned()))?;

let mime_str = content_type.as_str().ok_or(Error::BadRequest(format!(
"invalid Content-Type header: {content_type}"
)))?;

let mime: Mime = mime_str.parse().map_err(|e: FromStrError| {
Error::BadRequest(format!("failed to parse Content-Type header: {e}"))
})?;

if mime.essence_str() != expected_media_type {
return Err(Error::BadRequest(format!(
"unexpected Content-Type header: {mime}"
)));
}

Ok(())
}

/// Parse a [`TaskId`] from the "task_id" parameter in a set of path parameter
Expand Down
12 changes: 12 additions & 0 deletions aggregator/src/aggregator/http_handlers/tests/report.rs
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,18 @@ async fn upload_handler() {
assert!(test_conn.take_response_body().is_none());
}

// Upload a report with a versioned media-type header
let mut test_conn = post(task.report_upload_uri().unwrap().path())
.with_request_header(
KnownHeaderName::ContentType,
format!("{};version=07", Report::MEDIA_TYPE),
)
.with_request_body(report.get_encoded().unwrap())
.run_async(&handler)
.await;
assert_eq!(test_conn.status(), Some(Status::Created));
assert!(test_conn.take_response_body().is_none());

let accepted_report_id = report.metadata().id();

// Verify that new reports using an existing report ID are also accepted as a duplicate.
Expand Down
1 change: 1 addition & 0 deletions collector/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ fixed = { workspace = true, optional = true }
hpke-dispatch = { workspace = true, features = ["serde"] }
janus_core.workspace = true
janus_messages.workspace = true
mime.workspace = true
prio.workspace = true
rand = { workspace = true, features = ["min_const_gen"] }
reqwest = { workspace = true, features = ["json"] }
Expand Down
11 changes: 10 additions & 1 deletion collector/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,7 @@ use janus_messages::{
AggregateShareAad, BatchSelector, CollectionJobId, CollectionJobReq, CollectionJobResp,
PartialBatchSelector, Query, Role, TaskId,
};
use mime::Mime;
use prio::{
codec::{Decode, Encode, ParameterizedDecode},
vdaf,
Expand Down Expand Up @@ -131,6 +132,8 @@ pub enum Error {
ReportCountOverflow,
#[error("message error: {0}")]
Message(#[from] janus_messages::Error),
#[error("the response from the server was invalid: {0}")]
BadResponse(String),
}

impl From<HttpErrorResponse> for Error {
Expand Down Expand Up @@ -570,7 +573,13 @@ impl<V: vdaf::Collector> Collector<V> {
.headers()
.get(CONTENT_TYPE)
.ok_or(Error::BadContentType(None))?;
if content_type != CollectionJobResp::<TimeInterval>::MEDIA_TYPE {
let mime: Mime = content_type
.to_str()?
.parse()
.map_err(|e: mime::FromStrError| {
Error::BadResponse(format!("failed to parse Content-Type header: {e}"))
})?;
if mime.essence_str() != CollectionJobResp::<TimeInterval>::MEDIA_TYPE {
return Err(Error::BadContentType(Some(content_type.clone())));
}

Expand Down
1 change: 1 addition & 0 deletions core/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@ http-api-problem.workspace = true
janus_messages.workspace = true
k8s-openapi = { workspace = true, optional = true }
kube = { workspace = true, optional = true, features = ["rustls-tls"] }
mime.workspace = true
prio = { workspace = true, default-features = true, features = ["experimental"] }
quickcheck = { workspace = true, optional = true }
rand.workspace = true
Expand Down
20 changes: 13 additions & 7 deletions core/src/http.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ use anyhow::anyhow;
use http::StatusCode;
use http_api_problem::{HttpApiProblem, PROBLEM_JSON_MEDIA_TYPE};
use janus_messages::problem_type::DapProblemType;
use mime;
use reqwest::{header::CONTENT_TYPE, Response};
use std::fmt::{self, Display, Formatter};
use tracing::warn;
Expand All @@ -21,15 +22,20 @@ impl HttpErrorResponse {
/// the response's status code. (see [RFC 7807](https://www.rfc-editor.org/rfc/rfc7807.html))
pub async fn from_response(response: Response) -> Self {
let status = response.status();

if let Some(content_type) = response.headers().get(CONTENT_TYPE) {
if content_type == PROBLEM_JSON_MEDIA_TYPE {
match response.json::<HttpApiProblem>().await {
Ok(mut problem) => {
problem.status = Some(status);
// Unwrap safety: the conversion always succeeds if the status is populated.
return problem.try_into().unwrap();
if let Ok(content_type_str) = content_type.to_str() {
jcjones marked this conversation as resolved.
Show resolved Hide resolved
if let Ok(mime) = content_type_str.parse::<mime::Mime>() {
if mime.essence_str() == PROBLEM_JSON_MEDIA_TYPE {
match response.json::<HttpApiProblem>().await {
Ok(mut problem) => {
problem.status = Some(status);
// Unwrap safety: the conversion always succeeds if the status is populated.
return problem.try_into().unwrap();
}
Err(error) => warn!(%error, "Failed to parse problem details"),
}
}
Err(error) => warn!(%error, "Failed to parse problem details"),
}
}
}
Expand Down
65 changes: 65 additions & 0 deletions core/src/retries.rs
Original file line number Diff line number Diff line change
Expand Up @@ -267,6 +267,7 @@ mod tests {
test_util::install_test_trace_subscriber,
};
use backoff::Notify;
use http_api_problem::PROBLEM_JSON_MEDIA_TYPE;
use reqwest::StatusCode;
use std::time::Duration;
use tokio::net::TcpListener;
Expand Down Expand Up @@ -485,4 +486,68 @@ mod tests {
listener_task.abort();
assert!(listener_task.await.unwrap_err().is_cancelled());
}

#[tokio::test]
async fn http_retry_server_json_problem() {
install_test_trace_subscriber();
let mut server = mockito::Server::new_async().await;

let mock_problem = server
.mock("GET", "/")
.with_status(400)
.with_header("Content-Type", PROBLEM_JSON_MEDIA_TYPE)
.with_body("{\"type\":\"evil://league_of_evil/bad.horse?hes.bad\"}")
.expect(1)
.create_async()
.await;

let http_client = reqwest::Client::builder().build().unwrap();

let response = retry_http_request(LimitedRetryer::new(0), || async {
http_client.get(server.url()).send().await
})
.await
.unwrap_err()
.unwrap();

assert_eq!(response.status(), StatusCode::BAD_REQUEST);
assert_eq!(
response.type_uri(),
Some("evil://league_of_evil/bad.horse?hes.bad")
);
mock_problem.assert_async().await;
}

#[tokio::test]
async fn http_retry_server_json_problem_versioned_doc() {
install_test_trace_subscriber();
let mut server = mockito::Server::new_async().await;

// Ensure that even with a complex media type, we parse a problem
// document.
let mock_problem = server
.mock("GET", "/")
.with_status(418)
.with_header(
"Content-Type",
"application/problem+json; charset=\"utf-8\"; version=4",
)
.with_body("{\"title\":\"too many eels\"}")
.expect(1)
.create_async()
.await;

let http_client = reqwest::Client::builder().build().unwrap();

let response = retry_http_request(LimitedRetryer::new(0), || async {
http_client.get(server.url()).send().await
})
.await
.unwrap_err()
.unwrap();

assert_eq!(response.status(), StatusCode::IM_A_TEAPOT);
assert_eq!(response.title(), Some("too many eels"));
mock_problem.assert_async().await;
}
}
Loading