feat(host): impl API "/admin/pause" #440

Merged: 1 commit, Dec 27, 2024
7 changes: 7 additions & 0 deletions host/src/interfaces.rs
@@ -72,6 +72,10 @@ pub enum HostError {
/// For task manager errors.
#[error("There was an error with the task manager: {0}")]
TaskManager(#[from] TaskManagerError),

/// For system paused state.
#[error("System is paused")]
SystemPaused,
}

impl IntoResponse for HostError {
@@ -91,6 +95,7 @@ impl IntoResponse for HostError {
HostError::Anyhow(e) => ("anyhow_error", e.to_string()),
HostError::HandleDropped => ("handle_dropped", "".to_owned()),
HostError::CapacityFull => ("capacity_full", "".to_owned()),
HostError::SystemPaused => ("system_paused", "".to_owned()),
};
let status = Status::Error {
error: error.to_owned(),
@@ -130,6 +135,7 @@ impl From<HostError> for TaskStatus {
HostError::RPC(e) => TaskStatus::NetworkFailure(e.to_string()),
HostError::Guest(e) => TaskStatus::GuestProverFailure(e.to_string()),
HostError::TaskManager(e) => TaskStatus::TaskDbCorruption(e.to_string()),
HostError::SystemPaused => TaskStatus::SystemPaused,
}
}
}
@@ -151,6 +157,7 @@ impl From<&HostError> for TaskStatus {
HostError::RPC(e) => TaskStatus::NetworkFailure(e.to_string()),
HostError::Guest(e) => TaskStatus::GuestProverFailure(e.to_string()),
HostError::TaskManager(e) => TaskStatus::TaskDbCorruption(e.to_string()),
HostError::SystemPaused => TaskStatus::SystemPaused,
}
}
}
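For orientation (not part of this diff): a minimal sketch of how a request path could surface the new variant by checking the pause flag before accepting work. The helper name, its call site, and the exact import paths are assumptions, not code from this PR.

    use crate::{interfaces::{HostError, HostResult}, ProverState};

    /// Hypothetical helper: reject new work while the host is paused.
    fn ensure_not_paused(state: &ProverState) -> HostResult<()> {
        if state.is_paused() {
            // Surfaces as ("system_paused", ...) through the IntoResponse impl above
            // and as TaskStatus::SystemPaused through the From impls above.
            return Err(HostError::SystemPaused);
        }
        Ok(())
    }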
39 changes: 35 additions & 4 deletions host/src/lib.rs
@@ -1,3 +1,5 @@
use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::Arc;
use std::{alloc, path::PathBuf};

use anyhow::Context;
@@ -108,10 +110,11 @@ impl Opts {

/// Read the options from a file and merge it with the current options.
pub fn merge_from_file(&mut self) -> HostResult<()> {
let file = std::fs::File::open(&self.config_path)?;
let file = std::fs::File::open(&self.config_path).context("Failed to open config file")?;
let reader = std::io::BufReader::new(file);
let mut config: Value = serde_json::from_reader(reader)?;
let this = serde_json::to_value(&self)?;
let mut config: Value =
serde_json::from_reader(reader).context("Failed to read config file")?;
let this = serde_json::to_value(&self).context("Failed to serialize Opts")?;
merge(&mut config, &this);

*self = serde_json::from_value(config)?;
@@ -150,15 +153,17 @@ pub struct ProverState {
pub opts: Opts,
pub chain_specs: SupportedChainSpecs,
pub task_channel: mpsc::Sender<Message>,
pause_flag: Arc<AtomicBool>,
}

#[derive(Debug, Serialize)]
#[derive(Debug)]
pub enum Message {
Cancel(ProofTaskDescriptor),
Task(ProofRequest),
TaskComplete(ProofRequest),
CancelAggregate(AggregationOnlyRequest),
Aggregate(AggregationOnlyRequest),
SystemPause(tokio::sync::oneshot::Sender<HostResult<()>>),
}

impl ProverState {
@@ -188,6 +193,7 @@ impl ProverState {
}

let (task_channel, receiver) = mpsc::channel::<Message>(opts.concurrency_limit);
let pause_flag = Arc::new(AtomicBool::new(false));

let opts_clone = opts.clone();
let chain_specs_clone = chain_specs.clone();
@@ -202,6 +208,7 @@ impl ProverState {
opts,
chain_specs,
task_channel,
pause_flag,
})
}

@@ -212,6 +219,30 @@ impl ProverState {
pub fn request_config(&self) -> ProofRequestOpt {
self.opts.proof_request_opt.clone()
}

pub fn is_paused(&self) -> bool {
self.pause_flag.load(Ordering::SeqCst)
}

/// Set the pause flag and notify the task manager to pause, then wait for the task manager to
/// finish the pause process.
///
/// Note that this function does not return until the task manager has finished the pause process.
pub async fn set_pause(&self, paused: bool) -> HostResult<()> {
self.pause_flag.store(paused, Ordering::SeqCst);
if paused {
// Notify task manager to start pause process
let (sender, receiver) = tokio::sync::oneshot::channel();
self.task_channel
.try_send(Message::SystemPause(sender))
.context("Failed to send pause message")?;

// Wait for the pause message to be processed
let result = receiver.await.context("Failed to receive pause message")?;
return result;
}
Ok(())
}
}

#[global_allocator]
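The route registration for the new endpoint is outside the hunks shown here. As a minimal sketch, assuming an axum-style handler (HostError already implements IntoResponse above), the "/admin/pause" route could drive ProverState::set_pause as below; the handler name, state extraction, and response type are illustrative, not taken from this PR.

    use axum::{extract::State, http::StatusCode};

    // Hypothetical handler: set the pause flag and wait for the ProofActor to finish
    // cancelling in-flight work before responding.
    // Assumes ProverState: Clone, as required by axum's State extractor.
    async fn admin_pause(State(state): State<ProverState>) -> Result<StatusCode, HostError> {
        state.set_pause(true).await?;
        Ok(StatusCode::OK)
    }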
253 changes: 253 additions & 0 deletions host/src/proof.rs
@@ -302,6 +302,10 @@ impl ProofActor {
.expect("Couldn't acquire permit");
self.run_aggregate(request, permit).await;
}
Message::SystemPause(notifier) => {
let result = self.handle_system_pause().await;
let _ = notifier.send(result);
}
}
}
}
@@ -384,6 +388,90 @@ impl ProofActor {

Ok(())
}

async fn cancel_all_running_tasks(&mut self) -> HostResult<()> {
info!("Cancelling all running tasks");

// Clone the task set instead of holding the lock across cancellation: the cancel
// functions below re-acquire these locks internally, so holding one here would deadlock.
let running_tasks = {
let running_tasks = self.running_tasks.lock().await;
(*running_tasks).clone()
};

// Cancel all running tasks; keep going even if some cancellations fail.
let mut final_result = Ok(());
for proof_task_descriptor in running_tasks.keys() {
match self.cancel_task(proof_task_descriptor.clone()).await {
Ok(()) => {
info!(
"Cancel task during system pause, task: {:?}",
proof_task_descriptor
);
}
Err(e) => {
error!(
"Failed to cancel task during system pause: {}, task: {:?}",
e, proof_task_descriptor
);
final_result = final_result.and(Err(e));
}
}
}
final_result
}

async fn cancel_all_aggregation_tasks(&mut self) -> HostResult<()> {
info!("Cancelling all aggregation tasks");

// Clone the task set instead of holding the lock across cancellation: the cancel
// functions below re-acquire these locks internally, so holding one here would deadlock.
let aggregate_tasks = {
let aggregate_tasks = self.aggregate_tasks.lock().await;
(*aggregate_tasks).clone()
};

// Cancel all aggregation tasks; keep going even if some cancellations fail.
let mut final_result = Ok(());
for request in aggregate_tasks.keys() {
match self.cancel_aggregation_task(request.clone()).await {
Ok(()) => {
info!(
"Cancel aggregation task during system pause, task: {}",
request
);
}
Err(e) => {
error!(
"Failed to cancel aggregation task during system pause: {}, task: {}",
e, request
);
final_result = final_result.and(Err(e));
}
}
}
final_result
}

async fn handle_system_pause(&mut self) -> HostResult<()> {
info!("System pausing");

let mut final_result = Ok(());

self.pending_tasks.lock().await.clear();

if let Err(e) = self.cancel_all_running_tasks().await {
final_result = final_result.and(Err(e));
}

if let Err(e) = self.cancel_all_aggregation_tasks().await {
final_result = final_result.and(Err(e));
}

// TODO(Kero): make sure all tasks are saved to database, including pending tasks.

final_result
}
}

pub async fn handle_proof(
@@ -483,3 +571,168 @@ pub async fn handle_proof(

Ok(proof)
}

#[cfg(test)]
mod tests {
use super::*;
use tokio::sync::mpsc;

#[tokio::test]
async fn test_handle_system_pause_happy_path() {
let (tx, rx) = mpsc::channel(100);
let mut actor = setup_actor_with_tasks(tx, rx);

let result = actor.handle_system_pause().await;
assert!(result.is_ok());
}

#[tokio::test]
async fn test_handle_system_pause_with_pending_tasks() {
let (tx, rx) = mpsc::channel(100);
let mut actor = setup_actor_with_tasks(tx, rx);

// Add some pending tasks
actor.pending_tasks.lock().await.push_back(ProofRequest {
block_number: 1,
l1_inclusion_block_number: 1,
network: "test".to_string(),
l1_network: "test".to_string(),
graffiti: B256::ZERO,
prover: Default::default(),
proof_type: Default::default(),
blob_proof_type: Default::default(),
prover_args: HashMap::new(),
});

let result = actor.handle_system_pause().await;
assert!(result.is_ok());

// Verify pending tasks were cleared
assert_eq!(actor.pending_tasks.lock().await.len(), 0);
}

#[tokio::test]
async fn test_handle_system_pause_with_running_tasks() {
let (tx, rx) = mpsc::channel(100);
let mut actor = setup_actor_with_tasks(tx, rx);

// Add some running tasks
let task_descriptor = ProofTaskDescriptor::default();
let cancellation_token = CancellationToken::new();
actor
.running_tasks
.lock()
.await
.insert(task_descriptor.clone(), cancellation_token.clone());

let result = actor.handle_system_pause().await;
assert!(result.is_ok());

// Verify running tasks were cancelled
assert!(cancellation_token.is_cancelled());

// TODO(Kero): Cancelled tasks should be removed from running_tasks
// assert_eq!(actor.running_tasks.lock().await.len(), 0);
}

#[tokio::test]
async fn test_handle_system_pause_with_aggregation_tasks() {
let (tx, rx) = mpsc::channel(100);
let mut actor = setup_actor_with_tasks(tx, rx);

// Add some aggregation tasks
let request = AggregationOnlyRequest::default();
let cancellation_token = CancellationToken::new();
actor
.aggregate_tasks
.lock()
.await
.insert(request.clone(), cancellation_token.clone());

let result = actor.handle_system_pause().await;
assert!(result.is_ok());

// Verify aggregation tasks were cancelled
assert!(cancellation_token.is_cancelled());
// TODO(Kero): Cancelled tasks should be removed from aggregate_tasks
// assert_eq!(actor.aggregate_tasks.lock().await.len(), 0);
}

#[tokio::test]
async fn test_handle_system_pause_with_failures() {
let (tx, rx) = mpsc::channel(100);
let mut actor = setup_actor_with_tasks(tx, rx);

// Add some pending tasks
{
actor.pending_tasks.lock().await.push_back(ProofRequest {
block_number: 1,
l1_inclusion_block_number: 1,
network: "test".to_string(),
l1_network: "test".to_string(),
graffiti: B256::ZERO,
prover: Default::default(),
proof_type: Default::default(),
blob_proof_type: Default::default(),
prover_args: HashMap::new(),
});
}

let good_running_task_token = {
// Add some running tasks
let task_descriptor = ProofTaskDescriptor::default();
let cancellation_token = CancellationToken::new();
actor
.running_tasks
.lock()
.await
.insert(task_descriptor.clone(), cancellation_token.clone());
cancellation_token
};

let good_aggregation_task_token = {
// Add some aggregation tasks
let request = AggregationOnlyRequest::default();
let cancellation_token = CancellationToken::new();
actor
.aggregate_tasks
.lock()
.await
.insert(request.clone(), cancellation_token.clone());
cancellation_token
};

// Setup tasks that will fail to cancel
{
let task_descriptor_should_fail_cause_not_supported_error = ProofTaskDescriptor {
proof_system: ProofType::Risc0,
..Default::default()
};
actor.running_tasks.lock().await.insert(
task_descriptor_should_fail_cause_not_supported_error.clone(),
CancellationToken::new(),
);
}

let result = actor.handle_system_pause().await;

// Verify the first cancellation failure is propagated (Result::and keeps the first error)
assert!(matches!(
result,
Err(HostError::Core(RaikoError::FeatureNotSupportedError(..)))
));
assert!(good_running_task_token.is_cancelled());
assert!(good_aggregation_task_token.is_cancelled());
assert!(actor.pending_tasks.lock().await.is_empty());
}

// Helper function to setup actor with common test configuration
fn setup_actor_with_tasks(tx: Sender<Message>, rx: Receiver<Message>) -> ProofActor {
let opts = Opts {
concurrency_limit: 4,
..Default::default()
};

ProofActor::new(tx, rx, opts, SupportedChainSpecs::default())
}
}
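One behavior worth noting from handle_system_pause above: failures are folded together with Result::and, which keeps the first error and discards later ones rather than collecting them all. A minimal standalone sketch of that semantics (plain std Rust, not part of this diff):

    // Result::and returns self when self is already Err, so folding with
    // `final_result.and(Err(e))` preserves the first failure.
    fn main() {
        let mut final_result: Result<(), String> = Ok(());
        for failure in ["first failure", "second failure"] {
            final_result = final_result.and(Err(failure.to_string()));
        }
        assert_eq!(final_result, Err("first failure".to_string()));
    }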