diff --git a/host/src/interfaces.rs b/host/src/interfaces.rs
index 9256009e..301f370c 100644
--- a/host/src/interfaces.rs
+++ b/host/src/interfaces.rs
@@ -72,6 +72,10 @@ pub enum HostError {
     /// For task manager errors.
     #[error("There was an error with the task manager: {0}")]
     TaskManager(#[from] TaskManagerError),
+
+    /// For system paused state.
+    #[error("System is paused")]
+    SystemPaused,
 }
 
 impl IntoResponse for HostError {
@@ -91,6 +95,7 @@ impl IntoResponse for HostError {
             HostError::Anyhow(e) => ("anyhow_error", e.to_string()),
             HostError::HandleDropped => ("handle_dropped", "".to_owned()),
             HostError::CapacityFull => ("capacity_full", "".to_owned()),
+            HostError::SystemPaused => ("system_paused", "".to_owned()),
         };
         let status = Status::Error {
             error: error.to_owned(),
@@ -130,6 +135,7 @@ impl From<HostError> for TaskStatus {
             HostError::RPC(e) => TaskStatus::NetworkFailure(e.to_string()),
             HostError::Guest(e) => TaskStatus::GuestProverFailure(e.to_string()),
             HostError::TaskManager(e) => TaskStatus::TaskDbCorruption(e.to_string()),
+            HostError::SystemPaused => TaskStatus::SystemPaused,
         }
     }
 }
@@ -151,6 +157,7 @@ impl From<&HostError> for TaskStatus {
             HostError::RPC(e) => TaskStatus::NetworkFailure(e.to_string()),
             HostError::Guest(e) => TaskStatus::GuestProverFailure(e.to_string()),
             HostError::TaskManager(e) => TaskStatus::TaskDbCorruption(e.to_string()),
+            HostError::SystemPaused => TaskStatus::SystemPaused,
         }
     }
 }
diff --git a/host/src/lib.rs b/host/src/lib.rs
index 93c6719a..930bc0d0 100644
--- a/host/src/lib.rs
+++ b/host/src/lib.rs
@@ -1,3 +1,5 @@
+use std::sync::atomic::{AtomicBool, Ordering};
+use std::sync::Arc;
 use std::{alloc, path::PathBuf};
 
 use anyhow::Context;
@@ -108,10 +110,11 @@ impl Opts {
 
     /// Read the options from a file and merge it with the current options.
     pub fn merge_from_file(&mut self) -> HostResult<()> {
-        let file = std::fs::File::open(&self.config_path)?;
+        let file = std::fs::File::open(&self.config_path).context("Failed to open config file")?;
         let reader = std::io::BufReader::new(file);
-        let mut config: Value = serde_json::from_reader(reader)?;
-        let this = serde_json::to_value(&self)?;
+        let mut config: Value =
+            serde_json::from_reader(reader).context("Failed to read config file")?;
+        let this = serde_json::to_value(&self).context("Failed to deserialize Opts")?;
         merge(&mut config, &this);
 
         *self = serde_json::from_value(config)?;
@@ -150,15 +153,17 @@ pub struct ProverState {
     pub opts: Opts,
     pub chain_specs: SupportedChainSpecs,
     pub task_channel: mpsc::Sender<Message>,
+    pause_flag: Arc<AtomicBool>,
 }
 
-#[derive(Debug, Serialize)]
+#[derive(Debug)]
 pub enum Message {
     Cancel(ProofTaskDescriptor),
     Task(ProofRequest),
     TaskComplete(ProofRequest),
     CancelAggregate(AggregationOnlyRequest),
     Aggregate(AggregationOnlyRequest),
+    SystemPause(tokio::sync::oneshot::Sender<HostResult<()>>),
 }
 
 impl ProverState {
@@ -188,6 +193,7 @@ impl ProverState {
         }
 
         let (task_channel, receiver) = mpsc::channel::<Message>(opts.concurrency_limit);
+        let pause_flag = Arc::new(AtomicBool::new(false));
 
         let opts_clone = opts.clone();
         let chain_specs_clone = chain_specs.clone();
@@ -202,6 +208,7 @@ impl ProverState {
             opts,
             chain_specs,
             task_channel,
+            pause_flag,
         })
     }
 
@@ -212,6 +219,30 @@ impl ProverState {
     pub fn request_config(&self) -> ProofRequestOpt {
         self.opts.proof_request_opt.clone()
     }
+
+    pub fn is_paused(&self) -> bool {
+        self.pause_flag.load(Ordering::SeqCst)
+    }
+
+    /// Set the pause flag and notify the task manager to pause, then wait for the task manager to
+    /// finish the pause process.
+    ///
+    /// Note that this function blocks until the task manager finishes the pause process.
+    pub async fn set_pause(&self, paused: bool) -> HostResult<()> {
+        self.pause_flag.store(paused, Ordering::SeqCst);
+        if paused {
+            // Notify the task manager to start the pause process
+            let (sender, receiver) = tokio::sync::oneshot::channel();
+            self.task_channel
+                .try_send(Message::SystemPause(sender))
+                .context("Failed to send pause message")?;
+
+            // Wait for the pause message to be processed
+            let result = receiver.await.context("Failed to receive pause message")?;
+            return result;
+        }
+        Ok(())
+    }
 }
 
 #[global_allocator]
diff --git a/host/src/proof.rs b/host/src/proof.rs
index 1223af0c..d01de70e 100644
--- a/host/src/proof.rs
+++ b/host/src/proof.rs
@@ -302,6 +302,10 @@ impl ProofActor {
                         .expect("Couldn't acquire permit");
                     self.run_aggregate(request, permit).await;
                 }
+                Message::SystemPause(notifier) => {
+                    let result = self.handle_system_pause().await;
+                    let _ = notifier.send(result);
+                }
             }
         }
     }
@@ -384,6 +388,90 @@ impl ProofActor {
 
         Ok(())
     }
+
+    async fn cancel_all_running_tasks(&mut self) -> HostResult<()> {
+        info!("Cancelling all running tasks");
+
+        // Clone the task map so the lock is not held while cancelling; the internal cancel
+        // functions take the same lock, which would otherwise deadlock.
+        let running_tasks = {
+            let running_tasks = self.running_tasks.lock().await;
+            (*running_tasks).clone()
+        };
+
+        // Cancel all running tasks; keep going even if an individual cancellation fails.
+        let mut final_result = Ok(());
+        for proof_task_descriptor in running_tasks.keys() {
+            match self.cancel_task(proof_task_descriptor.clone()).await {
+                Ok(()) => {
+                    info!(
+                        "Cancel task during system pause, task: {:?}",
+                        proof_task_descriptor
+                    );
+                }
+                Err(e) => {
+                    error!(
+                        "Failed to cancel task during system pause: {}, task: {:?}",
+                        e, proof_task_descriptor
+                    );
+                    final_result = final_result.and(Err(e));
+                }
+            }
+        }
+        final_result
+    }
+
+    async fn cancel_all_aggregation_tasks(&mut self) -> HostResult<()> {
+        info!("Cancelling all aggregation tasks");
+
+        // Clone the task map so the lock is not held while cancelling; the internal cancel
+        // functions take the same lock, which would otherwise deadlock.
+        let aggregate_tasks = {
+            let aggregate_tasks = self.aggregate_tasks.lock().await;
+            (*aggregate_tasks).clone()
+        };
+
+        // Cancel all aggregation tasks; keep going even if an individual cancellation fails.
+        let mut final_result = Ok(());
+        for request in aggregate_tasks.keys() {
+            match self.cancel_aggregation_task(request.clone()).await {
+                Ok(()) => {
+                    info!(
+                        "Cancel aggregation task during system pause, task: {}",
+                        request
+                    );
+                }
+                Err(e) => {
+                    error!(
+                        "Failed to cancel aggregation task during system pause: {}, task: {}",
+                        e, request
+                    );
+                    final_result = final_result.and(Err(e));
+                }
+            }
+        }
+        final_result
+    }
+
+    async fn handle_system_pause(&mut self) -> HostResult<()> {
+        info!("System pausing");
+
+        let mut final_result = Ok(());
+
+        self.pending_tasks.lock().await.clear();
+
+        if let Err(e) = self.cancel_all_running_tasks().await {
+            final_result = final_result.and(Err(e));
+        }
+
+        if let Err(e) = self.cancel_all_aggregation_tasks().await {
+            final_result = final_result.and(Err(e));
+        }
+
+        // TODO(Kero): make sure all tasks are saved to the database, including pending tasks.
+
+        final_result
+    }
 }
 
 pub async fn handle_proof(
@@ -483,3 +571,168 @@ pub async fn handle_proof(
 
     Ok(proof)
 }
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+    use tokio::sync::mpsc;
+
+    #[tokio::test]
+    async fn test_handle_system_pause_happy_path() {
+        let (tx, rx) = mpsc::channel(100);
+        let mut actor = setup_actor_with_tasks(tx, rx);
+
+        let result = actor.handle_system_pause().await;
+        assert!(result.is_ok());
+    }
+
+    #[tokio::test]
+    async fn test_handle_system_pause_with_pending_tasks() {
+        let (tx, rx) = mpsc::channel(100);
+        let mut actor = setup_actor_with_tasks(tx, rx);
+
+        // Add some pending tasks
+        actor.pending_tasks.lock().await.push_back(ProofRequest {
+            block_number: 1,
+            l1_inclusion_block_number: 1,
+            network: "test".to_string(),
+            l1_network: "test".to_string(),
+            graffiti: B256::ZERO,
+            prover: Default::default(),
+            proof_type: Default::default(),
+            blob_proof_type: Default::default(),
+            prover_args: HashMap::new(),
+        });
+
+        let result = actor.handle_system_pause().await;
+        assert!(result.is_ok());
+
+        // Verify pending tasks were cleared
+        assert_eq!(actor.pending_tasks.lock().await.len(), 0);
+    }
+
+    #[tokio::test]
+    async fn test_handle_system_pause_with_running_tasks() {
+        let (tx, rx) = mpsc::channel(100);
+        let mut actor = setup_actor_with_tasks(tx, rx);
+
+        // Add some running tasks
+        let task_descriptor = ProofTaskDescriptor::default();
+        let cancellation_token = CancellationToken::new();
+        actor
+            .running_tasks
+            .lock()
+            .await
+            .insert(task_descriptor.clone(), cancellation_token.clone());
+
+        let result = actor.handle_system_pause().await;
+        assert!(result.is_ok());
+
+        // Verify running tasks were cancelled
+        assert!(cancellation_token.is_cancelled());
+
+        // TODO(Kero): Cancelled tasks should be removed from running_tasks
+        // assert_eq!(actor.running_tasks.lock().await.len(), 0);
+    }
+
+    #[tokio::test]
+    async fn test_handle_system_pause_with_aggregation_tasks() {
+        let (tx, rx) = mpsc::channel(100);
+        let mut actor = setup_actor_with_tasks(tx, rx);
+
+        // Add some aggregation tasks
+        let request = AggregationOnlyRequest::default();
+        let cancellation_token = CancellationToken::new();
+        actor
+            .aggregate_tasks
+            .lock()
+            .await
+            .insert(request.clone(), cancellation_token.clone());
+
+        let result = actor.handle_system_pause().await;
+        assert!(result.is_ok());
+
+        // Verify aggregation tasks were cancelled
+        assert!(cancellation_token.is_cancelled());
+        // TODO(Kero): Cancelled tasks should be removed from aggregate_tasks
+        // assert_eq!(actor.aggregate_tasks.lock().await.len(), 0);
+    }
+
+    #[tokio::test]
+    async fn test_handle_system_pause_with_failures() {
+        let (tx, rx) = mpsc::channel(100);
+        let mut actor = setup_actor_with_tasks(tx, rx);
+
+        // Add some pending tasks
+        {
+            actor.pending_tasks.lock().await.push_back(ProofRequest {
+                block_number: 1,
+                l1_inclusion_block_number: 1,
+                network: "test".to_string(),
+                l1_network: "test".to_string(),
+                graffiti: B256::ZERO,
+                prover: Default::default(),
+                proof_type: Default::default(),
+                blob_proof_type: Default::default(),
+                prover_args: HashMap::new(),
+            });
+        }
+
+        let good_running_task_token = {
+            // Add some running tasks
+            let task_descriptor = ProofTaskDescriptor::default();
+            let cancellation_token = CancellationToken::new();
+            actor
+                .running_tasks
+                .lock()
+                .await
+                .insert(task_descriptor.clone(), cancellation_token.clone());
+            cancellation_token
+        };
+
+        let good_aggregation_task_token = {
+            // Add some aggregation tasks
+            let request = AggregationOnlyRequest::default();
+            let cancellation_token = CancellationToken::new();
+            actor
+                .aggregate_tasks
+                .lock()
+                .await
+                .insert(request.clone(), cancellation_token.clone());
+            cancellation_token
+        };
+
+        // Set up tasks that will fail to cancel
+        {
+            let task_descriptor_should_fail_cause_not_supported_error = ProofTaskDescriptor {
+                proof_system: ProofType::Risc0,
+                ..Default::default()
+            };
+            actor.running_tasks.lock().await.insert(
+                task_descriptor_should_fail_cause_not_supported_error.clone(),
+                CancellationToken::new(),
+            );
+        }
+
+        let result = actor.handle_system_pause().await;
+
+        // Verify the error contains all accumulated errors
+        assert!(matches!(
+            result,
+            Err(HostError::Core(RaikoError::FeatureNotSupportedError(..)))
+        ));
+        assert!(good_running_task_token.is_cancelled());
+        assert!(good_aggregation_task_token.is_cancelled());
+        assert!(actor.pending_tasks.lock().await.is_empty());
+    }
+
+    // Helper function to set up an actor with the common test configuration
+    fn setup_actor_with_tasks(tx: Sender<Message>, rx: Receiver<Message>) -> ProofActor {
+        let opts = Opts {
+            concurrency_limit: 4,
+            ..Default::default()
+        };
+
+        ProofActor::new(tx, rx, opts, SupportedChainSpecs::default())
+    }
+}
diff --git a/host/src/server/api/admin.rs b/host/src/server/api/admin.rs
new file mode 100644
index 00000000..948e59f7
--- /dev/null
+++ b/host/src/server/api/admin.rs
@@ -0,0 +1,114 @@
+use axum::{extract::State, routing::post, Router};
+
+use crate::{interfaces::HostResult, ProverState};
+
+pub fn create_router() -> Router<ProverState> {
+    Router::new()
+        .route("/admin/pause", post(pause))
+        .route("/admin/unpause", post(unpause))
+}
+
+async fn pause(State(state): State<ProverState>) -> HostResult<&'static str> {
+    state.set_pause(true).await?;
+    Ok("System paused successfully")
+}
+
+async fn unpause(State(state): State<ProverState>) -> HostResult<&'static str> {
+    state.set_pause(false).await?;
+    Ok("System unpaused successfully")
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+    use axum::{
+        body::Body,
+        http::{Request, StatusCode},
+    };
+    use clap::Parser;
+    use std::path::PathBuf;
+    use tower::ServiceExt;
+
+    #[tokio::test]
+    async fn test_pause() {
+        let opts = {
+            let mut opts = crate::Opts::parse();
+            opts.config_path = PathBuf::from("../host/config/config.json");
+            opts.merge_from_file().unwrap();
+            opts
+        };
+        let state = ProverState::init_with_opts(opts).unwrap();
+        let app = Router::new()
+            .route("/admin/pause", post(pause))
+            .with_state(state.clone());
+
+        let request = Request::builder()
+            .method("POST")
+            .uri("/admin/pause")
+            .body(Body::empty())
+            .unwrap();
+
+        let response = app.oneshot(request).await.unwrap();
+
+        assert_eq!(response.status(), StatusCode::OK);
+        assert!(state.is_paused());
+    }
+
+    #[tokio::test]
+    async fn test_pause_when_already_paused() {
+        let opts = {
+            let mut opts = crate::Opts::parse();
+            opts.config_path = PathBuf::from("../host/config/config.json");
+            opts.merge_from_file().unwrap();
+            opts
+        };
+        let state = ProverState::init_with_opts(opts).unwrap();
+
+        state.set_pause(true).await.unwrap();
+
+        let app = Router::new()
+            .route("/admin/pause", post(pause))
+            .with_state(state.clone());
+
+        let request = Request::builder()
+            .method("POST")
+            .uri("/admin/pause")
+            .body(Body::empty())
+            .unwrap();
+
+        let response = app.oneshot(request).await.unwrap();
+
+        assert_eq!(response.status(), StatusCode::OK);
+        assert!(state.is_paused());
+    }
+
+    #[tokio::test]
+    async fn test_unpause() {
+        let opts = {
+            let mut opts = crate::Opts::parse();
+            opts.config_path = PathBuf::from("../host/config/config.json");
+            opts.merge_from_file().unwrap();
+            opts
+        };
+        let state = ProverState::init_with_opts(opts).unwrap();
+
+        // Set initial paused state
+        state.set_pause(true).await.unwrap();
+        assert!(state.is_paused());
+
+        let app = Router::new()
+            .route("/admin/unpause", post(unpause))
+            .with_state(state.clone());
+
+        let request = Request::builder()
+            .method("POST")
+            .uri("/admin/unpause")
+            .body(Body::empty())
+            .unwrap();
+
+        let response = app.oneshot(request).await.unwrap();
+
+        assert_eq!(response.status(), StatusCode::OK);
+        assert!(!state.is_paused());
+    }
+}
diff --git a/host/src/server/api/mod.rs b/host/src/server/api/mod.rs
index 45be92f1..06a0b7c0 100644
--- a/host/src/server/api/mod.rs
+++ b/host/src/server/api/mod.rs
@@ -16,6 +16,8 @@ use tower_http::{
 
 use crate::ProverState;
 
+pub mod admin;
+pub mod util;
 pub mod v1;
 pub mod v2;
 pub mod v3;
@@ -39,12 +41,14 @@ pub fn create_router(concurrency_limit: usize, jwt_secret: Option<&str>) -> Rout
     let v1_api = v1::create_router(concurrency_limit);
     let v2_api = v2::create_router();
     let v3_api = v3::create_router();
+    let admin_api = admin::create_router();
 
     let router = Router::new()
         .nest("/v1", v1_api)
        .nest("/v2", v2_api)
         .nest("/v3", v3_api.clone())
         .merge(v3_api)
+        .nest("/admin", admin_api)
         .layer(middleware)
         .layer(middleware::from_fn(check_max_body_size))
         .layer(trace)
diff --git a/host/src/server/api/util.rs b/host/src/server/api/util.rs
new file mode 100644
index 00000000..d47c1da2
--- /dev/null
+++ b/host/src/server/api/util.rs
@@ -0,0 +1,12 @@
+use crate::{
+    interfaces::{HostError, HostResult},
+    ProverState,
+};
+
+/// Ensure that the system is not paused; otherwise return an error.
+pub fn ensure_not_paused(prover_state: &ProverState) -> HostResult<()> {
+    if prover_state.is_paused() {
+        return Err(HostError::SystemPaused);
+    }
+    Ok(())
+}
diff --git a/host/src/server/api/v1/proof.rs b/host/src/server/api/v1/proof.rs
index 572194fc..1437a2a0 100644
--- a/host/src/server/api/v1/proof.rs
+++ b/host/src/server/api/v1/proof.rs
@@ -8,7 +8,7 @@ use crate::{
     interfaces::HostResult,
     metrics::{dec_current_req, inc_current_req, inc_guest_req_count, inc_host_req_count},
     proof::handle_proof,
-    server::api::v1::Status,
+    server::api::{util::ensure_not_paused, v1::Status},
     ProverState,
 };
 
@@ -35,6 +35,9 @@ async fn proof_handler(
     Json(req): Json<Value>,
 ) -> HostResult<Json<Status>> {
     inc_current_req();
+
+    ensure_not_paused(&prover_state)?;
+
     // Override the existing proof request config from the config file and command line
     // options with the request from the client.
     let mut config = prover_state.request_config();
diff --git a/host/src/server/api/v2/proof/mod.rs b/host/src/server/api/v2/proof/mod.rs
index 3e52aa26..dfcc10e2 100644
--- a/host/src/server/api/v2/proof/mod.rs
+++ b/host/src/server/api/v2/proof/mod.rs
@@ -7,7 +7,7 @@ use utoipa::OpenApi;
 use crate::{
     interfaces::HostResult,
     metrics::{inc_current_req, inc_guest_req_count, inc_host_req_count},
-    server::api::v2::Status,
+    server::api::{util::ensure_not_paused, v2::Status},
     Message, ProverState,
 };
 
@@ -37,6 +37,9 @@ async fn proof_handler(
     Json(req): Json<Value>,
 ) -> HostResult<Status> {
     inc_current_req();
+
+    ensure_not_paused(&prover_state)?;
+
     // Override the existing proof request config from the config file and command line
     // options with the request from the client.
     let mut config = prover_state.request_config();
diff --git a/host/src/server/api/v3/proof/aggregate/cancel.rs b/host/src/server/api/v3/proof/aggregate/cancel.rs
index 6ed8ca18..d97aa745 100644
--- a/host/src/server/api/v3/proof/aggregate/cancel.rs
+++ b/host/src/server/api/v3/proof/aggregate/cancel.rs
@@ -65,7 +65,8 @@ async fn cancel_handler(
             | TaskStatus::GuestProverFailure(_)
             | TaskStatus::InvalidOrUnsupportedBlock
             | TaskStatus::UnspecifiedFailureReason
-            | TaskStatus::TaskDbCorruption(_) => {
+            | TaskStatus::TaskDbCorruption(_)
+            | TaskStatus::SystemPaused => {
                 should_signal_cancel = true;
                 CancelStatus::Error {
                     error: "Task already completed".to_string(),
diff --git a/host/src/server/api/v3/proof/mod.rs b/host/src/server/api/v3/proof/mod.rs
index 7c57c5e5..e8668e77 100644
--- a/host/src/server/api/v3/proof/mod.rs
+++ b/host/src/server/api/v3/proof/mod.rs
@@ -1,6 +1,7 @@
 use crate::{
     interfaces::HostResult,
     metrics::{inc_current_req, inc_guest_req_count, inc_host_req_count},
+    server::api::util::ensure_not_paused,
     server::api::{v2, v3::Status},
     Message, ProverState,
 };
@@ -38,6 +39,9 @@ async fn proof_handler(
     Json(mut aggregation_request): Json<AggregationRequest>,
 ) -> HostResult<Status> {
     inc_current_req();
+
+    ensure_not_paused(&prover_state)?;
+
     // Override the existing proof request config from the config file and command line
     // options with the request from the client.
     aggregation_request.merge(&prover_state.request_config())?;
@@ -230,3 +234,48 @@ pub fn create_router() -> Router<ProverState> {
         .nest("/list", v2::proof::list::create_router())
         .nest("/prune", v2::proof::prune::create_router())
 }
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+    use axum::{body::Body, http::Request};
+    use clap::Parser;
+    use std::path::PathBuf;
+    use tower::ServiceExt;
+
+    #[tokio::test]
+    async fn test_proof_handler_when_paused() {
+        let opts = {
+            let mut opts = crate::Opts::parse();
+            opts.config_path = PathBuf::from("../host/config/config.json");
+            opts.merge_from_file().unwrap();
+            opts
+        };
+        let state = ProverState::init_with_opts(opts).unwrap();
+        let app = Router::new()
+            .route("/", post(proof_handler))
+            .with_state(state.clone());
+
+        // Set the pause flag
+        state.set_pause(true).await.unwrap();
+
+        let request = Request::builder()
+            .method("POST")
+            .uri("/")
+            .header("content-type", "application/json")
+            .body(Body::from(
+                r#"{"block_numbers":[],"proof_type":"block","prover":"native"}"#,
+            ))
+            .unwrap();
+
+        let response = app.oneshot(request).await.unwrap();
+        let body = axum::body::to_bytes(response.into_body(), 1024)
+            .await
+            .unwrap();
+        assert!(
+            String::from_utf8_lossy(&body).contains("system_paused"),
+            "{:?}",
+            body
+        );
+    }
+}
diff --git a/taskdb/src/lib.rs b/taskdb/src/lib.rs
index 71916d99..513843a1 100644
--- a/taskdb/src/lib.rs
+++ b/taskdb/src/lib.rs
@@ -71,6 +71,7 @@ pub enum TaskStatus {
     GuestProverFailure(String),
     UnspecifiedFailureReason,
     TaskDbCorruption(String),
+    SystemPaused,
 }
 
 impl From<TaskStatus> for i32 {
@@ -92,6 +93,7 @@ impl From<TaskStatus> for i32 {
             TaskStatus::GuestProverFailure(_) => -7000,
             TaskStatus::UnspecifiedFailureReason => -8000,
             TaskStatus::TaskDbCorruption(_) => -9000,
+            TaskStatus::SystemPaused => -10000,
         }
     }
 }
@@ -115,6 +117,7 @@ impl From<i32> for TaskStatus {
             -7000 => TaskStatus::GuestProverFailure("".to_string()),
             -8000 => TaskStatus::UnspecifiedFailureReason,
             -9000 => TaskStatus::TaskDbCorruption("".to_string()),
+            -10000 => TaskStatus::SystemPaused,
             _ => TaskStatus::UnspecifiedFailureReason,
         }
     }
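
Reviewer note: a rough, hypothetical usage sketch of the new admin endpoints from an operator's point of view, not part of this patch. It assumes a `reqwest` blocking client and a host at http://localhost:8080; the mount path is also an assumption, since admin::create_router() registers "/admin/pause" while api/mod.rs additionally nests that router under "/admin", so adjust the URL to the final wiring.

// Hypothetical operator-side client for the pause/unpause endpoints added above.
// Assumes `reqwest` with the "blocking" feature and a reachable host URL; both are
// illustration-only assumptions, not something this patch provides.
fn main() -> Result<(), Box<dyn std::error::Error>> {
    let client = reqwest::blocking::Client::new();
    let base = "http://localhost:8080";

    // Pause: the host clears pending tasks, cancels running and aggregation tasks,
    // and further proof requests fail with the "system_paused" error code.
    let resp = client.post(format!("{base}/admin/pause")).send()?;
    println!("pause -> {}: {}", resp.status(), resp.text()?);

    // ... perform maintenance while the system is paused ...

    // Unpause: clears the pause flag so new proof requests are accepted again.
    let resp = client.post(format!("{base}/admin/unpause")).send()?;
    println!("unpause -> {}: {}", resp.status(), resp.text()?);

    Ok(())
}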