From 27c41258cfa10b5a88598364a6605b8502f79a8e Mon Sep 17 00:00:00 2001 From: Praveen Perera Date: Sun, 13 Oct 2024 15:47:21 -0500 Subject: [PATCH] refactor(async)!: remove async-std dependency, allow custom runtime --- Cargo.toml | 10 +++++++--- src/async.rs | 31 ++++++++++++++++++++++++++----- src/lib.rs | 24 ++++++++++++++++++++---- 3 files changed, 53 insertions(+), 12 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index fed23d8..e9e0f49 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -23,7 +23,9 @@ hex = { version = "0.2", package = "hex-conservative" } log = "^0.4" minreq = { version = "2.11.0", features = ["json-using-serde"], optional = true } reqwest = { version = "0.11", features = ["json"], default-features = false, optional = true } -async-std = { version = "1.13.0", optional = true } + +# default async runtime +tokio = { version = "1", features = ["time"], optional = true } [dev-dependencies] serde_json = "1.0" @@ -32,13 +34,15 @@ electrsd = { version = "0.28.0", features = ["legacy", "esplora_a33e97e1", "bitc lazy_static = "1.4.0" [features] -default = ["blocking", "async", "async-https"] +default = ["blocking", "async", "async-https", "tokio"] blocking = ["minreq", "minreq/proxy"] blocking-https = ["blocking", "minreq/https"] blocking-https-rustls = ["blocking", "minreq/https-rustls"] blocking-https-native = ["blocking", "minreq/https-native"] blocking-https-bundled = ["blocking", "minreq/https-bundled"] -async = ["async-std", "reqwest", "reqwest/socks"] + +tokio = ["dep:tokio"] +async = ["reqwest", "reqwest/socks", "tokio?/time"] async-https = ["async", "reqwest/default-tls"] async-https-native = ["async", "reqwest/native-tls"] async-https-rustls = ["async", "reqwest/rustls-tls"] diff --git a/src/async.rs b/src/async.rs index 73bf386..91a28d7 100644 --- a/src/async.rs +++ b/src/async.rs @@ -11,8 +11,8 @@ //! Esplora by way of `reqwest` HTTP client. -use async_std::task; use std::collections::HashMap; +use std::marker::PhantomData; use std::str::FromStr; use bitcoin::consensus::{deserialize, serialize, Decodable, Encodable}; @@ -35,16 +35,19 @@ use crate::{ }; #[derive(Debug, Clone)] -pub struct AsyncClient { +pub struct AsyncClient { /// The URL of the Esplora Server. url: String, /// The inner [`reqwest::Client`] to make HTTP requests. client: Client, /// Number of times to retry a request max_retries: usize, + + /// Marker for the type of sleeper used + marker: PhantomData, } -impl AsyncClient { +impl AsyncClient { /// Build an async client from a builder pub fn from_builder(builder: Builder) -> Result { let mut client_builder = Client::builder(); @@ -75,15 +78,16 @@ impl AsyncClient { url: builder.base_url, client: client_builder.build()?, max_retries: builder.max_retries, + marker: PhantomData, }) } - /// Build an async client from the base url and [`Client`] pub fn from_client(url: String, client: Client) -> Self { AsyncClient { url, client, max_retries: crate::DEFAULT_MAX_RETRIES, + marker: PhantomData, } } @@ -460,7 +464,7 @@ impl AsyncClient { loop { match self.client.get(url).send().await? { resp if attempts < self.max_retries && is_status_retryable(resp.status()) => { - task::sleep(delay).await; + S::sleep(delay).await; attempts += 1; delay *= 2; } @@ -473,3 +477,20 @@ impl AsyncClient { fn is_status_retryable(status: reqwest::StatusCode) -> bool { RETRYABLE_ERROR_CODES.contains(&status.as_u16()) } + +pub trait Sleeper: 'static { + type Sleep: std::future::Future; + fn sleep(dur: std::time::Duration) -> Self::Sleep; +} + +#[derive(Debug, Clone, Copy)] +pub struct DefaultSleeper; + +#[cfg(any(test, feature = "tokio"))] +impl Sleeper for DefaultSleeper { + type Sleep = tokio::time::Sleep; + + fn sleep(dur: std::time::Duration) -> Self::Sleep { + tokio::time::sleep(dur) + } +} diff --git a/src/lib.rs b/src/lib.rs index 743f118..0bc7bf8 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -26,7 +26,7 @@ //! Here is an example of how to create an asynchronous client. //! //! ```no_run -//! # #[cfg(feature = "async")] +//! # #[cfg(all(feature = "async", feature = "tokio"))] //! # { //! use esplora_client::Builder; //! let builder = Builder::new("https://blockstream.info/testnet/api"); @@ -71,8 +71,10 @@ use std::fmt; use std::num::TryFromIntError; use std::time::Duration; -pub mod api; +#[cfg(feature = "async")] +pub use r#async::Sleeper; +pub mod api; #[cfg(feature = "async")] pub mod r#async; #[cfg(feature = "blocking")] @@ -178,11 +180,18 @@ impl Builder { BlockingClient::from_builder(self) } - // Build an asynchronous client from builder - #[cfg(feature = "async")] + /// Build an asynchronous client from builder + #[cfg(all(feature = "async", feature = "tokio"))] pub fn build_async(self) -> Result { AsyncClient::from_builder(self) } + + /// Build an asynchronous client from builder where the returned client uses a + /// user-defined [`Sleeper`]. + #[cfg(feature = "async")] + pub fn build_async_with_sleeper(self) -> Result, Error> { + AsyncClient::from_builder(self) + } } /// Errors that can happen during a request to `Esplora` servers. @@ -320,8 +329,15 @@ mod test { let blocking_client = builder.build_blocking(); let builder_async = Builder::new(&format!("http://{}", esplora_url)); + + #[cfg(feature = "tokio")] let async_client = builder_async.build_async().unwrap(); + #[cfg(not(feature = "tokio"))] + let async_client = builder_async + .build_async_with_sleeper::() + .unwrap(); + (blocking_client, async_client) }