From 313a6eb933adb26a5397b0731376a79885a5c896 Mon Sep 17 00:00:00 2001 From: Mehul Date: Mon, 6 Jan 2025 20:41:37 -0500 Subject: [PATCH] refactor: Move retry logic from infer_type_name to wizard --- Cargo.lock | 12 ++++++ Cargo.toml | 2 + src/cli/llm/error.rs | 1 + src/cli/llm/infer_type_name.rs | 68 ++++++++++++++++------------------ src/cli/llm/wizard.rs | 38 +++++++++++++++---- 5 files changed, 77 insertions(+), 44 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index e36f881d60..bd8d994061 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -5711,6 +5711,7 @@ dependencies = [ "test-log", "thiserror 1.0.69", "tokio", + "tokio-retry", "tokio-test", "tonic 0.11.0", "tonic-types", @@ -6170,6 +6171,17 @@ dependencies = [ "syn 2.0.90", ] +[[package]] +name = "tokio-retry" +version = "0.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7f57eb36ecbe0fc510036adff84824dd3c24bb781e21bfa67b69d556aa85214f" +dependencies = [ + "pin-project", + "rand", + "tokio", +] + [[package]] name = "tokio-rustls" version = "0.24.1" diff --git a/Cargo.toml b/Cargo.toml index 9ee5d9b070..c378b63a27 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -16,6 +16,7 @@ headers = "0.3.9" # previous version until hyper is updated to 1+ http = "0.2.12" # previous version until hyper is updated to 1+ insta = { version = "1.38.0", features = ["json"] } tokio = { version = "1.37.0", features = ["rt", "time"] } +tokio-retry = "0.3" reqwest = { version = "0.11", features = [ "json", "rustls-tls", @@ -66,6 +67,7 @@ rustls-pemfile = { version = "1.0.4" } schemars = { version = "0.8.17", features = ["derive"] } hyper = { version = "0.14.28", features = ["server"], default-features = false } tokio = { workspace = true } +tokio-retry = { workspace = true } anyhow = { workspace = true } reqwest = { workspace = true } derive_setters = "0.1.6" diff --git a/src/cli/llm/error.rs b/src/cli/llm/error.rs index c2b44f9ca3..0712266908 100644 --- a/src/cli/llm/error.rs +++ b/src/cli/llm/error.rs @@ -6,6 +6,7 @@ pub enum Error { GenAI(genai::Error), EmptyResponse, Serde(serde_json::Error), + Reqwest(reqwest::Error), } pub type Result = std::result::Result; diff --git a/src/cli/llm/infer_type_name.rs b/src/cli/llm/infer_type_name.rs index 7183eacbde..2f2812c7d2 100644 --- a/src/cli/llm/infer_type_name.rs +++ b/src/cli/llm/infer_type_name.rs @@ -123,46 +123,40 @@ impl InferTypeName { .collect(), }; - let mut delay = 3; - loop { - let answer = self.wizard.ask(question.clone()).await; - match answer { - Ok(answer) => { - let name = &answer.suggestions.join(", "); - for name in answer.suggestions { - if config.types.contains_key(&name) || used_type_names.contains(&name) { - continue; - } - used_type_names.insert(name.clone()); - new_name_mappings.insert(type_name.to_owned(), name); - break; + // Directly use the wizard's ask method to get a result + let answer = self.wizard.ask(question.clone()).await; + + match answer { + Ok(answer) => { + let name = &answer.suggestions.join(", "); + for name in answer.suggestions { + if config.types.contains_key(&name) || used_type_names.contains(&name) { + continue; } - tracing::info!( - "Suggestions for {}: [{}] - {}/{}", - type_name, - name, - i + 1, - total - ); - - // TODO: case where suggested names are already used, then extend the base - // question with `suggest different names, we have already used following - // names: [names list]` + used_type_names.insert(name.clone()); + new_name_mappings.insert(type_name.to_owned(), name); break; } - Err(e) => { - // TODO: log errors after certain number of retries. - if let Error::GenAI(_) = e { - // TODO: retry only when it's required. - tracing::warn!( - "Unable to retrieve a name for the type '{}'. Retrying in {}s", - type_name, - delay - ); - tokio::time::sleep(tokio::time::Duration::from_secs(delay)).await; - delay *= std::cmp::min(delay * 2, 60); - } - } + tracing::info!( + "Suggestions for {}: [{}] - {}/{}", + type_name, + name, + i + 1, + total + ); + + // TODO: case where suggested names are already used, then + // extend the base question with + // `suggest different names, we have already used following + // names: [names list]` + } + Err(e) => { + // Handle errors in case of failure + tracing::error!( + "Failed to get suggestions for type '{}': {:?}", + type_name, + e + ); } } } diff --git a/src/cli/llm/wizard.rs b/src/cli/llm/wizard.rs index 46d7a18624..90a063626c 100644 --- a/src/cli/llm/wizard.rs +++ b/src/cli/llm/wizard.rs @@ -3,8 +3,11 @@ use genai::adapter::AdapterKind; use genai::chat::{ChatOptions, ChatRequest, ChatResponse}; use genai::resolver::AuthResolver; use genai::Client; +use reqwest::StatusCode; +use tokio_retry::strategy::ExponentialBackoff; +use tokio_retry::RetryIf; -use super::Result; +use super::error::{Error, Result}; #[derive(Setters, Clone)] pub struct Wizard { @@ -40,13 +43,34 @@ impl Wizard { pub async fn ask(&self, q: Q) -> Result where - Q: TryInto, + Q: TryInto + Clone, A: TryFrom, { - let response = self - .client - .exec_chat(self.model.as_str(), q.try_into()?, None) - .await?; - A::try_from(response) + let retry_strategy = ExponentialBackoff::from_millis(500) + .max_delay(std::time::Duration::from_secs(30)) + .take(5); + + RetryIf::spawn( + retry_strategy, + || async { + let request = q.clone().try_into()?; // Convert the question to a request + self.client + .exec_chat(self.model.as_str(), request, None) // Execute chat request + .await + .map_err(Error::from) + .and_then(A::try_from) // Convert the response into the + // desired result + }, + |err: &Error| { + // Check if the error is a ReqwestError and if the status is 429 + if let Error::Reqwest(reqwest_err) = err { + if let Some(status) = reqwest_err.status() { + return status == StatusCode::TOO_MANY_REQUESTS; + } + } + false + }, + ) + .await } }