Skip to content

Commit

Permalink
allow request mutation and async code in prehooks (#457)
Browse files Browse the repository at this point in the history
  • Loading branch information
jclulow authored Feb 4, 2024
1 parent a58f336 commit 3ff2ec1
Show file tree
Hide file tree
Showing 27 changed files with 1,412 additions and 683 deletions.
17 changes: 17 additions & 0 deletions example-macro/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,27 @@ generate_api!(
pre_hook = (|_, request| {
println!("doing this {:?}", request);
}),
pre_hook_async = crate::add_auth_headers,
post_hook = crate::all_done,
derives = [schemars::JsonSchema],
);

async fn add_auth_headers(
_: &(),
req: &mut reqwest::Request,
) -> Result<(), reqwest::header::InvalidHeaderValue> {
// You can perform asynchronous, fallible work in a request hook, then
// modify the request right before it is transmitted to the server; e.g.,
// for generating an authenticaiton signature based on the complete set of
// request header values:
req.headers_mut().insert(
reqwest::header::AUTHORIZATION,
reqwest::header::HeaderValue::from_str("legitimate")?,
);

Ok(())
}

fn all_done(_: &(), _result: &reqwest::Result<reqwest::Response>) {}

mod buildomat {
Expand Down
8 changes: 8 additions & 0 deletions progenitor-client/src/progenitor_client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -249,13 +249,17 @@ pub enum Error<E = ()> {
/// A response not listed in the API description. This may represent a
/// success or failure response; check `status().is_success()`.
UnexpectedResponse(reqwest::Response),

/// An error occurred in the processing of a request pre-hook.
PreHookError(String),
}

impl<E> Error<E> {
/// Returns the status code, if the error was generated from a response.
pub fn status(&self) -> Option<reqwest::StatusCode> {
match self {
Error::InvalidRequest(_) => None,
Error::PreHookError(_) => None,
Error::CommunicationError(e) => e.status(),
Error::ErrorResponse(rv) => Some(rv.status()),
Error::InvalidUpgrade(e) => e.status(),
Expand All @@ -272,6 +276,7 @@ impl<E> Error<E> {
pub fn into_untyped(self) -> Error {
match self {
Error::InvalidRequest(s) => Error::InvalidRequest(s),
Error::PreHookError(s) => Error::PreHookError(s),
Error::CommunicationError(e) => Error::CommunicationError(e),
Error::ErrorResponse(ResponseValue {
inner: _,
Expand Down Expand Up @@ -332,6 +337,9 @@ where
Error::UnexpectedResponse(r) => {
write!(f, "Unexpected Response: {:?}", r)
}
Error::PreHookError(s) => {
write!(f, "Pre-hook Error: {}", s)
}
}
}
}
Expand Down
7 changes: 7 additions & 0 deletions progenitor-impl/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,7 @@ pub struct GenerationSettings {
tag: TagStyle,
inner_type: Option<TokenStream>,
pre_hook: Option<TokenStream>,
pre_hook_async: Option<TokenStream>,
post_hook: Option<TokenStream>,
extra_derives: Vec<String>,

Expand Down Expand Up @@ -128,6 +129,12 @@ impl GenerationSettings {
self
}

/// Hook invoked before issuing the HTTP request.
pub fn with_pre_hook_async(&mut self, pre_hook: TokenStream) -> &mut Self {
self.pre_hook_async = Some(pre_hook);
self
}

/// Hook invoked prior to receiving the HTTP response.
pub fn with_post_hook(&mut self, post_hook: TokenStream) -> &mut Self {
self.post_hook = Some(post_hook);
Expand Down
12 changes: 11 additions & 1 deletion progenitor-impl/src/method.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1141,6 +1141,14 @@ impl Generator {
(#hook)(&#client.inner, &#request_ident);
}
});
let pre_hook_async = self.settings.pre_hook_async.as_ref().map(|hook| {
quote! {
match (#hook)(&#client.inner, &mut #request_ident).await {
Ok(_) => (),
Err(e) => return Err(Error::PreHookError(e.to_string())),
}
}
});
let post_hook = self.settings.post_hook.as_ref().map(|hook| {
quote! {
(#hook)(&#client.inner, &#result_ident);
Expand All @@ -1155,7 +1163,8 @@ impl Generator {

#headers_build

let #request_ident = #client.client
#[allow(unused_mut)]
let mut #request_ident = #client.client
. #method_func (#url_ident)
#accept_header
#(#body_func)*
Expand All @@ -1165,6 +1174,7 @@ impl Generator {
.build()?;

#pre_hook
#pre_hook_async
let #result_ident = #client.client
.execute(#request_ident)
.await;
Expand Down
57 changes: 38 additions & 19 deletions progenitor-impl/tests/output/src/buildomat_builder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2404,7 +2404,8 @@ pub mod builder {
pub async fn send(self) -> Result<ResponseValue<()>, Error<()>> {
let Self { client } = self;
let url = format!("{}/v1/control/hold", client.baseurl,);
let request = client
#[allow(unused_mut)]
let mut request = client
.client
.post(url)
.header(
Expand Down Expand Up @@ -2438,7 +2439,8 @@ pub mod builder {
pub async fn send(self) -> Result<ResponseValue<()>, Error<()>> {
let Self { client } = self;
let url = format!("{}/v1/control/resume", client.baseurl,);
let request = client.client.post(url).build()?;
#[allow(unused_mut)]
let mut request = client.client.post(url).build()?;
let result = client.client.execute(request).await;
let response = result?;
match response.status().as_u16() {
Expand Down Expand Up @@ -2484,7 +2486,8 @@ pub mod builder {
client.baseurl,
encode_path(&task.to_string()),
);
let request = client
#[allow(unused_mut)]
let mut request = client
.client
.get(url)
.header(
Expand Down Expand Up @@ -2518,7 +2521,8 @@ pub mod builder {
pub async fn send(self) -> Result<ResponseValue<Vec<types::Task>>, Error<()>> {
let Self { client } = self;
let url = format!("{}/v1/tasks", client.baseurl,);
let request = client
#[allow(unused_mut)]
let mut request = client
.client
.get(url)
.header(
Expand Down Expand Up @@ -2579,7 +2583,8 @@ pub mod builder {
.and_then(|v| types::TaskSubmit::try_from(v).map_err(|e| e.to_string()))
.map_err(Error::InvalidRequest)?;
let url = format!("{}/v1/tasks", client.baseurl,);
let request = client
#[allow(unused_mut)]
let mut request = client
.client
.post(url)
.header(
Expand Down Expand Up @@ -2655,7 +2660,8 @@ pub mod builder {
if let Some(v) = &minseq {
query.push(("minseq", v.to_string()));
}
let request = client
#[allow(unused_mut)]
let mut request = client
.client
.get(url)
.header(
Expand Down Expand Up @@ -2709,7 +2715,8 @@ pub mod builder {
client.baseurl,
encode_path(&task.to_string()),
);
let request = client
#[allow(unused_mut)]
let mut request = client
.client
.get(url)
.header(
Expand Down Expand Up @@ -2780,7 +2787,8 @@ pub mod builder {
encode_path(&task.to_string()),
encode_path(&output.to_string()),
);
let request = client.client.get(url).build()?;
#[allow(unused_mut)]
let mut request = client.client.get(url).build()?;
let result = client.client.execute(request).await;
let response = result?;
match response.status().as_u16() {
Expand Down Expand Up @@ -2834,7 +2842,8 @@ pub mod builder {
.and_then(|v| types::UserCreate::try_from(v).map_err(|e| e.to_string()))
.map_err(Error::InvalidRequest)?;
let url = format!("{}/v1/users", client.baseurl,);
let request = client
#[allow(unused_mut)]
let mut request = client
.client
.post(url)
.header(
Expand Down Expand Up @@ -2869,7 +2878,8 @@ pub mod builder {
pub async fn send(self) -> Result<ResponseValue<types::WhoamiResult>, Error<()>> {
let Self { client } = self;
let url = format!("{}/v1/whoami", client.baseurl,);
let request = client
#[allow(unused_mut)]
let mut request = client
.client
.get(url)
.header(
Expand Down Expand Up @@ -2919,7 +2929,8 @@ pub mod builder {
let Self { client, body } = self;
let body = body.map_err(Error::InvalidRequest)?;
let url = format!("{}/v1/whoami/name", client.baseurl,);
let request = client
#[allow(unused_mut)]
let mut request = client
.client
.put(url)
.header(
Expand Down Expand Up @@ -2981,7 +2992,8 @@ pub mod builder {
.and_then(|v| types::WorkerBootstrap::try_from(v).map_err(|e| e.to_string()))
.map_err(Error::InvalidRequest)?;
let url = format!("{}/v1/worker/bootstrap", client.baseurl,);
let request = client
#[allow(unused_mut)]
let mut request = client
.client
.post(url)
.header(
Expand Down Expand Up @@ -3016,7 +3028,8 @@ pub mod builder {
pub async fn send(self) -> Result<ResponseValue<types::WorkerPingResult>, Error<()>> {
let Self { client } = self;
let url = format!("{}/v1/worker/ping", client.baseurl,);
let request = client
#[allow(unused_mut)]
let mut request = client
.client
.get(url)
.header(
Expand Down Expand Up @@ -3096,7 +3109,8 @@ pub mod builder {
client.baseurl,
encode_path(&task.to_string()),
);
let request = client.client.post(url).json(&body).build()?;
#[allow(unused_mut)]
let mut request = client.client.post(url).json(&body).build()?;
let result = client.client.execute(request).await;
let response = result?;
match response.status().as_u16() {
Expand Down Expand Up @@ -3155,7 +3169,8 @@ pub mod builder {
client.baseurl,
encode_path(&task.to_string()),
);
let request = client
#[allow(unused_mut)]
let mut request = client
.client
.post(url)
.header(
Expand Down Expand Up @@ -3240,7 +3255,8 @@ pub mod builder {
client.baseurl,
encode_path(&task.to_string()),
);
let request = client.client.post(url).json(&body).build()?;
#[allow(unused_mut)]
let mut request = client.client.post(url).json(&body).build()?;
let result = client.client.execute(request).await;
let response = result?;
match response.status().as_u16() {
Expand Down Expand Up @@ -3311,7 +3327,8 @@ pub mod builder {
client.baseurl,
encode_path(&task.to_string()),
);
let request = client.client.post(url).json(&body).build()?;
#[allow(unused_mut)]
let mut request = client.client.post(url).json(&body).build()?;
let result = client.client.execute(request).await;
let response = result?;
match response.status().as_u16() {
Expand All @@ -3338,7 +3355,8 @@ pub mod builder {
pub async fn send(self) -> Result<ResponseValue<types::WorkersResult>, Error<()>> {
let Self { client } = self;
let url = format!("{}/v1/workers", client.baseurl,);
let request = client
#[allow(unused_mut)]
let mut request = client
.client
.get(url)
.header(
Expand Down Expand Up @@ -3372,7 +3390,8 @@ pub mod builder {
pub async fn send(self) -> Result<ResponseValue<()>, Error<()>> {
let Self { client } = self;
let url = format!("{}/v1/workers/recycle", client.baseurl,);
let request = client.client.post(url).build()?;
#[allow(unused_mut)]
let mut request = client.client.post(url).build()?;
let result = client.client.execute(request).await;
let response = result?;
match response.status().as_u16() {
Expand Down
Loading

0 comments on commit 3ff2ec1

Please sign in to comment.