diff --git a/src/wasm/client.rs b/src/wasm/client.rs index e33ed26d1..c8f8106e5 100644 --- a/src/wasm/client.rs +++ b/src/wasm/client.rs @@ -222,7 +222,10 @@ async fn fetch(req: Request) -> crate::Result { } } - let abort = AbortGuard::new()?; + let mut abort = AbortGuard::new()?; + if let Some(timeout) = req.timeout() { + abort.timeout(*timeout); + } init.signal(Some(&abort.signal())); let js_req = web_sys::Request::new_with_str_and_init(req.url().as_str(), &init) diff --git a/src/wasm/mod.rs b/src/wasm/mod.rs index e99fb11fb..852bb9a6c 100644 --- a/src/wasm/mod.rs +++ b/src/wasm/mod.rs @@ -1,4 +1,9 @@ -use wasm_bindgen::JsCast; +use std::convert::TryInto; +use std::time::Duration; + +use js_sys::Function; +use wasm_bindgen::prelude::{wasm_bindgen, Closure}; +use wasm_bindgen::{JsCast, JsValue}; use web_sys::{AbortController, AbortSignal}; mod body; @@ -14,6 +19,15 @@ pub use self::client::{Client, ClientBuilder}; pub use self::request::{Request, RequestBuilder}; pub use self::response::Response; +#[wasm_bindgen] +extern "C" { + #[wasm_bindgen(js_name = "setTimeout")] + fn set_timeout(handler: &Function, timeout: i32) -> JsValue; + + #[wasm_bindgen(js_name = "clearTimeout")] + fn clear_timeout(handle: JsValue) -> JsValue; +} + async fn promise(promise: js_sys::Promise) -> Result where T: JsCast, @@ -30,6 +44,7 @@ where /// A guard that cancels a fetch request when dropped. struct AbortGuard { ctrl: AbortController, + timeout: Option<(JsValue, Closure)>, } impl AbortGuard { @@ -38,16 +53,32 @@ impl AbortGuard { ctrl: AbortController::new() .map_err(crate::error::wasm) .map_err(crate::error::builder)?, + timeout: None, }) } fn signal(&self) -> AbortSignal { self.ctrl.signal() } + + fn timeout(&mut self, timeout: Duration) { + let ctrl = self.ctrl.clone(); + let abort = Closure::once(move || ctrl.abort()); + let timeout = set_timeout( + abort.as_ref().unchecked_ref::(), + timeout.as_millis().try_into().expect("timeout"), + ); + if let Some((id, _)) = self.timeout.replace((timeout, abort)) { + clear_timeout(id); + } + } } impl Drop for AbortGuard { fn drop(&mut self) { self.ctrl.abort(); + if let Some((id, _)) = self.timeout.take() { + clear_timeout(id); + } } } diff --git a/src/wasm/request.rs b/src/wasm/request.rs index e6f51ebc1..494fdcb8b 100644 --- a/src/wasm/request.rs +++ b/src/wasm/request.rs @@ -1,5 +1,6 @@ use std::convert::TryFrom; use std::fmt; +use std::time::Duration; use bytes::Bytes; use http::{request::Parts, Method, Request as HttpRequest}; @@ -18,6 +19,7 @@ pub struct Request { url: Url, headers: HeaderMap, body: Option, + timeout: Option, pub(super) cors: bool, pub(super) credentials: Option, } @@ -37,6 +39,7 @@ impl Request { url, headers: HeaderMap::new(), body: None, + timeout: None, cors: true, credentials: None, } @@ -90,6 +93,18 @@ impl Request { &mut self.body } + /// Get the timeout. + #[inline] + pub fn timeout(&self) -> Option<&Duration> { + self.timeout.as_ref() + } + + /// Get a mutable reference to the timeout. + #[inline] + pub fn timeout_mut(&mut self) -> &mut Option { + &mut self.timeout + } + /// Attempts to clone the `Request`. /// /// None is returned if a body is which can not be cloned. @@ -104,6 +119,7 @@ impl Request { url: self.url.clone(), headers: self.headers.clone(), body, + timeout: self.timeout.clone(), cors: self.cors, credentials: self.credentials, }) @@ -241,6 +257,14 @@ impl RequestBuilder { self } + /// Enables a request timeout. + pub fn timeout(mut self, timeout: Duration) -> RequestBuilder { + if let Ok(ref mut req) = self.request { + *req.timeout_mut() = Some(timeout); + } + self + } + /// TODO #[cfg(feature = "multipart")] #[cfg_attr(docsrs, doc(cfg(feature = "multipart")))] @@ -466,6 +490,7 @@ where url, headers, body: Some(body.into()), + timeout: None, cors: true, credentials: None, }) diff --git a/tests/timeouts.rs b/tests/timeouts.rs index 71dc0ce66..c7d3c95ac 100644 --- a/tests/timeouts.rs +++ b/tests/timeouts.rs @@ -65,7 +65,6 @@ async fn request_timeout() { assert_eq!(err.url().map(|u| u.as_str()), Some(url.as_str())); } -#[cfg(not(target_arch = "wasm32"))] #[tokio::test] async fn connect_timeout() { let _ = env_logger::try_init(); diff --git a/tests/wasm_simple.rs b/tests/wasm_simple.rs index fe314de4d..86df233d5 100644 --- a/tests/wasm_simple.rs +++ b/tests/wasm_simple.rs @@ -1,4 +1,5 @@ #![cfg(target_arch = "wasm32")] +use std::time::Duration; use wasm_bindgen::prelude::*; use wasm_bindgen_test::*; @@ -22,3 +23,17 @@ async fn simple_example() { let body = res.text().await.expect("response to utf-8 text"); log(&format!("Body:\n\n{body}")); } + +#[wasm_bindgen_test] +async fn request_with_timeout() { + let client = reqwest::Client::new(); + let err = client + .get("https://hyper.rs") + .timeout(Duration::from_millis(1)) + .send() + .await + .expect_err("Expected error from aborted request"); + + assert!(err.is_request()); + assert!(format!("{err:?}").contains("AbortError"), "{err:?}"); +}