Skip to content

Commit

Permalink
pass target to builder fn
Browse files Browse the repository at this point in the history
  • Loading branch information
hewigovens committed Oct 3, 2024
1 parent 5102cb5 commit 8ae080e
Show file tree
Hide file tree
Showing 3 changed files with 24 additions and 14 deletions.
4 changes: 2 additions & 2 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
[workspace]
package.version = "0.3.0"
package.version = "0.3.1"
package.edition = "2021"
package.documentation = "https://docs.rs/reqwest_enum"
package.authors = ["Tao Xu <[email protected]>"]
Expand Down
32 changes: 21 additions & 11 deletions reqwest-enum/src/provider.rs
Original file line number Diff line number Diff line change
Expand Up @@ -43,12 +43,12 @@ pub trait JsonRpcProviderType<T: Target>: ProviderType<T> {
}

pub type EndpointFn<T> = fn(target: &T) -> String;
pub type RequestBuilderFn =
fn(request_builder: &reqwest::RequestBuilder) -> reqwest::RequestBuilder;
pub type RequestBuilderFn<T> =
fn(request_builder: &reqwest::RequestBuilder, target: &T) -> reqwest::RequestBuilder;
pub struct Provider<T: Target> {
/// endpoint closure to customize the endpoint (url / path)
endpoint_fn: Option<EndpointFn<T>>,
request_fn: Option<RequestBuilderFn>,
request_fn: Option<RequestBuilderFn<T>>,
client: Client,
}

Expand Down Expand Up @@ -166,7 +166,10 @@ impl<T> Provider<T>
where
T: Target,
{
pub fn new(endpoint_fn: Option<EndpointFn<T>>, request_fn: Option<RequestBuilderFn>) -> Self {
pub fn new(
endpoint_fn: Option<EndpointFn<T>>,
request_fn: Option<RequestBuilderFn<T>>,
) -> Self {
let client = reqwest::Client::new();
Self {
client,
Expand Down Expand Up @@ -206,7 +209,7 @@ where
}
}
if let Some(request_fn) = &self.request_fn {
request = request_fn(&mut request);
request = request_fn(&mut request, target);
}
request
}
Expand All @@ -233,7 +236,9 @@ mod tests {
target::Target,
};
use serde::{Deserialize, Serialize};
use std::collections::hash_map::DefaultHasher;
use std::collections::HashMap;
use std::hash::{Hash, Hasher};
use tokio_test::block_on;
#[derive(Serialize, Deserialize)]
struct Person {
Expand Down Expand Up @@ -270,7 +275,7 @@ mod tests {
}

fn query(&self) -> HashMap<&'static str, &'static str> {
HashMap::default()
HashMap::from([("foo", "bar")])
}

fn headers(&self) -> HashMap<&'static str, &'static str> {
Expand Down Expand Up @@ -313,11 +318,15 @@ mod tests {
fn test_request_fn() {
let provider = Provider::<HttpBin>::new(
None,
Some(|builder: &reqwest::RequestBuilder| {
builder
.try_clone()
.expect("trying to clone request")
.header("X-test", "test")
Some(|builder: &reqwest::RequestBuilder, target: &HttpBin| {
let mut hasher = DefaultHasher::new();
target.query_string().hash(&mut hasher);
let hash = hasher.finish();

let mut req = builder.try_clone().expect("trying to clone request");
req = req.header("X-test", "test");
req = req.header("X-hash", format!("{}", hash));
req
}),
);

Expand All @@ -326,6 +335,7 @@ mod tests {

assert_eq!(request.method().to_string(), "GET");
assert_eq!(headers.get("X-test").unwrap(), "test");
assert_eq!(headers.get("X-hash").unwrap(), "3270317559611782182");
}

#[test]
Expand Down

0 comments on commit 8ae080e

Please sign in to comment.