diff --git a/reqpool/Cargo.toml b/reqpool/Cargo.toml new file mode 100644 index 00000000..4c417367 --- /dev/null +++ b/reqpool/Cargo.toml @@ -0,0 +1,27 @@ +[package] +name = "raiko-reqpool" +version = "0.1.0" +authors = ["Taiko Labs"] +edition = "2021" + +[dependencies] +raiko-lib = { workspace = true } +raiko-core = { workspace = true } +raiko-redis-derive = { workspace = true } +chrono = { workspace = true, features = ["serde"] } +serde = { workspace = true } +serde_json = { workspace = true } +serde_with = { workspace = true } +tracing = { workspace = true } +tokio = { workspace = true } +async-trait = { workspace = true } +redis = { workspace = true } +backoff = { workspace = true } +derive-getters = { workspace = true } +proc-macro2 = { workspace = true } +quote = { workspace = true } +syn = { workspace = true } +alloy-primitives = { workspace = true } + +[dev-dependencies] +lazy_static = { workspace = true } diff --git a/reqpool/src/config.rs b/reqpool/src/config.rs new file mode 100644 index 00000000..0050daa3 --- /dev/null +++ b/reqpool/src/config.rs @@ -0,0 +1,10 @@ +use serde::{Deserialize, Serialize}; + +#[derive(Debug, Clone, Serialize, Deserialize)] +/// The configuration for the redis-backend request pool +pub struct RedisPoolConfig { + /// The URL of the Redis database, e.g. "redis://localhost:6379" + pub redis_url: String, + /// The TTL of the Redis database + pub redis_ttl: u64, +} diff --git a/reqpool/src/lib.rs b/reqpool/src/lib.rs new file mode 100644 index 00000000..e2502116 --- /dev/null +++ b/reqpool/src/lib.rs @@ -0,0 +1,16 @@ +mod config; +mod macros; +#[cfg(any(test, feature = "enable-mock"))] +mod mock; +mod redis_pool; +mod request; +mod utils; + +// Re-export +pub use config::RedisPoolConfig; +pub use redis_pool::Pool; +pub use request::{ + AggregationRequestEntity, AggregationRequestKey, RequestEntity, RequestKey, + SingleProofRequestEntity, SingleProofRequestKey, Status, StatusWithContext, +}; +pub use utils::proof_key_to_hack_request_key; diff --git a/reqpool/src/macros.rs b/reqpool/src/macros.rs new file mode 100644 index 00000000..fb36b349 --- /dev/null +++ b/reqpool/src/macros.rs @@ -0,0 +1,44 @@ +/// This macro implements the Display trait for a type by using serde_json's pretty printing. +/// If the type cannot be serialized to JSON, it falls back to using Debug formatting. +/// +/// # Example +/// +/// ```rust +/// use serde::{Serialize, Deserialize}; +/// +/// #[derive(Debug, Serialize, Deserialize)] +/// struct Person { +/// name: String, +/// age: u32 +/// } +/// +/// impl_display_using_json_pretty!(Person); +/// +/// let person = Person { +/// name: "John".to_string(), +/// age: 30 +/// }; +/// +/// // Will print: +/// // { +/// // "name": "John", +/// // "age": 30 +/// // } +/// println!("{}", person); +/// ``` +/// +/// The type must implement serde's Serialize trait for JSON serialization to work. +/// If serialization fails, it will fall back to using the Debug implementation. +#[macro_export] +macro_rules! impl_display_using_json_pretty { + ($type:ty) => { + impl std::fmt::Display for $type { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match serde_json::to_string_pretty(self) { + Ok(s) => write!(f, "{}", s), + Err(_) => write!(f, "{:?}", self), + } + } + } + }; +} diff --git a/reqpool/src/mock.rs b/reqpool/src/mock.rs new file mode 100644 index 00000000..27194e38 --- /dev/null +++ b/reqpool/src/mock.rs @@ -0,0 +1,145 @@ +use lazy_static::lazy_static; +use redis::{RedisError, RedisResult}; +use serde::Serialize; +use serde_json::{json, Value}; +use std::{ + collections::HashMap, + sync::{Arc, Mutex}, +}; + +type SingleStorage = Arc>>; +type GlobalStorage = Mutex>; + +lazy_static! { + // #{redis_url => single_storage} + // + // We use redis_url to distinguish different redis database for tests, to prevent + // data race problem when running multiple tests. + static ref GLOBAL_STORAGE: GlobalStorage = Mutex::new(HashMap::new()); +} + +pub struct MockRedisConnection { + storage: SingleStorage, +} + +impl MockRedisConnection { + pub(crate) fn new(redis_url: String) -> Self { + let mut global = GLOBAL_STORAGE.lock().unwrap(); + Self { + storage: global + .entry(redis_url) + .or_insert_with(|| Arc::new(Mutex::new(HashMap::new()))) + .clone(), + } + } + + pub fn set_ex( + &mut self, + key: K, + val: V, + _ttl: u64, + ) -> RedisResult<()> { + let mut lock = self.storage.lock().unwrap(); + lock.insert(json!(key), json!(val)); + Ok(()) + } + + pub fn get(&mut self, key: &K) -> RedisResult { + let lock = self.storage.lock().unwrap(); + match lock.get(&json!(key)) { + None => Err(RedisError::from((redis::ErrorKind::TypeError, "not found"))), + Some(v) => serde_json::from_value(v.clone()).map_err(|e| { + RedisError::from(( + redis::ErrorKind::TypeError, + "deserialization error", + e.to_string(), + )) + }), + } + } + + pub fn del(&mut self, key: K) -> RedisResult { + let mut lock = self.storage.lock().unwrap(); + if lock.remove(&json!(key)).is_none() { + Ok(0) + } else { + Ok(1) + } + } +} + +#[cfg(test)] +mod tests { + use redis::RedisResult; + + use crate::{Pool, RedisPoolConfig}; + + #[test] + fn test_mock_redis_pool() { + let config = RedisPoolConfig { + redis_ttl: 111, + redis_url: "redis://localhost:6379".to_string(), + }; + let mut pool = Pool::open(config).unwrap(); + let mut conn = pool.conn().expect("mock conn"); + + let key = "hello".to_string(); + let val = "world".to_string(); + conn.set_ex(key.clone(), val.clone(), 111) + .expect("mock set_ex"); + + let actual: RedisResult = conn.get(&key); + assert_eq!(actual, Ok(val)); + + let _ = conn.del(&key); + let actual: RedisResult = conn.get(&key); + assert!(actual.is_err()); + } + + #[test] + fn test_mock_multiple_redis_pool() { + let mut pool1 = Pool::open(RedisPoolConfig { + redis_ttl: 111, + redis_url: "redis://localhost:6379".to_string(), + }) + .unwrap(); + let mut pool2 = Pool::open(RedisPoolConfig { + redis_ttl: 111, + redis_url: "redis://localhost:6380".to_string(), + }) + .unwrap(); + + let mut conn1 = pool1.conn().expect("mock conn"); + let mut conn2 = pool2.conn().expect("mock conn"); + + let key = "hello".to_string(); + let world = "world".to_string(); + + { + conn1 + .set_ex(key.clone(), world.clone(), 111) + .expect("mock set_ex"); + let actual: RedisResult = conn1.get(&key); + assert_eq!(actual, Ok(world.clone())); + } + + { + let actual: RedisResult = conn2.get(&key); + assert!(actual.is_err()); + } + + { + let meme = "meme".to_string(); + conn2 + .set_ex(key.clone(), meme.clone(), 111) + .expect("mock set_ex"); + let actual: RedisResult = conn2.get(&key); + assert_eq!(actual, Ok(meme)); + } + + { + let actual: RedisResult = conn1.get(&key); + assert_eq!(actual, Ok(world)); + } + } +} diff --git a/reqpool/src/redis_pool.rs b/reqpool/src/redis_pool.rs new file mode 100644 index 00000000..62f29f01 --- /dev/null +++ b/reqpool/src/redis_pool.rs @@ -0,0 +1,208 @@ +use crate::{ + impl_display_using_json_pretty, proof_key_to_hack_request_key, RedisPoolConfig, RequestEntity, + RequestKey, StatusWithContext, +}; +use backoff::{exponential::ExponentialBackoff, SystemClock}; +use raiko_lib::prover::{IdStore, IdWrite, ProofKey, ProverError, ProverResult}; +use raiko_redis_derive::RedisValue; +#[allow(unused_imports)] +use redis::{Client, Commands, RedisResult}; +use serde::{Deserialize, Serialize}; +use std::time::Duration; + +#[derive(Debug, Clone)] +pub struct Pool { + client: Client, + config: RedisPoolConfig, +} + +impl Pool { + pub fn add( + &mut self, + request_key: RequestKey, + request_entity: RequestEntity, + status: StatusWithContext, + ) -> Result<(), String> { + tracing::info!("RedisPool.add: {request_key}, {status}"); + let request_entity_and_status = RequestEntityAndStatus { + entity: request_entity, + status, + }; + self.conn() + .map_err(|e| e.to_string())? + .set_ex( + request_key, + request_entity_and_status, + self.config.redis_ttl, + ) + .map_err(|e| e.to_string())?; + Ok(()) + } + + pub fn remove(&mut self, request_key: &RequestKey) -> Result { + tracing::info!("RedisPool.remove: {request_key}"); + let result: usize = self + .conn() + .map_err(|e| e.to_string())? + .del(request_key) + .map_err(|e| e.to_string())?; + Ok(result) + } + + pub fn get( + &mut self, + request_key: &RequestKey, + ) -> Result, String> { + let result: RedisResult = + self.conn().map_err(|e| e.to_string())?.get(request_key); + match result { + Ok(value) => Ok(Some(value.into())), + Err(e) if e.kind() == redis::ErrorKind::TypeError => Ok(None), + Err(e) => Err(e.to_string()), + } + } + + pub fn get_status( + &mut self, + request_key: &RequestKey, + ) -> Result, String> { + self.get(request_key).map(|v| v.map(|v| v.1)) + } + + pub fn update_status( + &mut self, + request_key: RequestKey, + status: StatusWithContext, + ) -> Result { + tracing::info!("RedisPool.update_status: {request_key}, {status}"); + match self.get(&request_key)? { + Some((entity, old_status)) => { + self.add(request_key, entity, status)?; + Ok(old_status) + } + None => Err("Request not found".to_string()), + } + } +} + +#[async_trait::async_trait] +impl IdStore for Pool { + async fn read_id(&mut self, proof_key: ProofKey) -> ProverResult { + let hack_request_key = proof_key_to_hack_request_key(proof_key); + + tracing::info!("RedisPool.read_id: {hack_request_key}"); + + let result: RedisResult = self + .conn() + .map_err(|e| e.to_string())? + .get(&hack_request_key); + match result { + Ok(value) => Ok(value.into()), + Err(e) => Err(ProverError::StoreError(e.to_string())), + } + } +} + +#[async_trait::async_trait] +impl IdWrite for Pool { + async fn store_id(&mut self, proof_key: ProofKey, id: String) -> ProverResult<()> { + let hack_request_key = proof_key_to_hack_request_key(proof_key); + + tracing::info!("RedisPool.store_id: {hack_request_key}, {id}"); + + self.conn() + .map_err(|e| e.to_string())? + .set_ex(hack_request_key, id, self.config.redis_ttl) + .map_err(|e| ProverError::StoreError(e.to_string()))?; + Ok(()) + } + + async fn remove_id(&mut self, proof_key: ProofKey) -> ProverResult<()> { + let hack_request_key = proof_key_to_hack_request_key(proof_key); + + tracing::info!("RedisPool.remove_id: {hack_request_key}"); + + self.conn() + .map_err(|e| e.to_string())? + .del(hack_request_key) + .map_err(|e| ProverError::StoreError(e.to_string()))?; + Ok(()) + } +} + +impl Pool { + pub fn open(config: RedisPoolConfig) -> Result { + tracing::info!("RedisPool.open: connecting to redis: {}", config.redis_url); + + let client = Client::open(config.redis_url.clone())?; + Ok(Self { client, config }) + } + + #[cfg(any(test, feature = "enable-mock"))] + pub(crate) fn conn(&mut self) -> Result { + Ok(crate::mock::MockRedisConnection::new( + self.config.redis_url.clone(), + )) + } + + #[cfg(not(any(test, feature = "enable-mock")))] + fn conn(&mut self) -> Result { + self.redis_conn() + } + + #[allow(dead_code)] + fn redis_conn(&mut self) -> Result { + let backoff: ExponentialBackoff = ExponentialBackoff { + initial_interval: Duration::from_secs(10), + max_interval: Duration::from_secs(60), + max_elapsed_time: Some(Duration::from_secs(300)), + ..Default::default() + }; + + backoff::retry(backoff, || match self.client.get_connection() { + Ok(conn) => Ok(conn), + Err(e) => { + tracing::error!( + "RedisPool.get_connection: failed to connect to redis: {e:?}, retrying..." + ); + + self.client = redis::Client::open(self.config.redis_url.clone())?; + Err(backoff::Error::Transient { + err: e, + retry_after: None, + }) + } + }) + .map_err(|e| match e { + backoff::Error::Transient { + err, + retry_after: _, + } + | backoff::Error::Permanent(err) => err, + }) + } +} + +/// A internal wrapper for request entity and status, used for redis serialization +#[derive(PartialEq, Debug, Clone, Deserialize, Serialize, RedisValue)] +struct RequestEntityAndStatus { + entity: RequestEntity, + status: StatusWithContext, +} + +impl From<(RequestEntity, StatusWithContext)> for RequestEntityAndStatus { + fn from(value: (RequestEntity, StatusWithContext)) -> Self { + Self { + entity: value.0, + status: value.1, + } + } +} + +impl From for (RequestEntity, StatusWithContext) { + fn from(value: RequestEntityAndStatus) -> Self { + (value.entity, value.status) + } +} + +impl_display_using_json_pretty!(RequestEntityAndStatus); diff --git a/reqpool/src/request.rs b/reqpool/src/request.rs new file mode 100644 index 00000000..f02cd3e2 --- /dev/null +++ b/reqpool/src/request.rs @@ -0,0 +1,304 @@ +use crate::impl_display_using_json_pretty; +use alloy_primitives::Address; +use chrono::{DateTime, Utc}; +use derive_getters::Getters; +use raiko_core::interfaces::ProverSpecificOpts; +use raiko_lib::{ + input::BlobProofType, + primitives::{ChainId, B256}, + proof_type::ProofType, + prover::Proof, +}; +use raiko_redis_derive::RedisValue; +use serde::{Deserialize, Serialize}; +use serde_with::{serde_as, DisplayFromStr}; +use std::collections::HashMap; + +#[derive(RedisValue, PartialEq, Debug, Clone, Deserialize, Serialize, Eq, PartialOrd, Ord)] +#[serde(rename_all = "snake_case")] +/// The status of a request +pub enum Status { + // === Normal status === + /// The request is registered but not yet started + Registered, + + /// The request is in progress + WorkInProgress, + + // /// The request is in progress of proving + // WorkInProgressProving { + // /// The proof ID + // /// For SP1 and RISC0 proof type, it is the proof ID returned by the network prover, + // /// otherwise, it should be empty. + // proof_id: String, + // }, + /// The request is successful + Success { + /// The proof of the request + proof: Proof, + }, + + // === Cancelled status === + /// The request is cancelled + Cancelled, + + // === Error status === + /// The request is failed with an error + Failed { + /// The error message + error: String, + }, +} + +impl Status { + pub fn is_success(&self) -> bool { + matches!(self, Status::Success { .. }) + } +} + +#[derive( + PartialEq, Debug, Clone, Deserialize, Serialize, Eq, PartialOrd, Ord, RedisValue, Getters, +)] +/// The status of a request with context +pub struct StatusWithContext { + /// The status of the request + status: Status, + /// The timestamp of the status + timestamp: DateTime, +} + +impl StatusWithContext { + pub fn new(status: Status, timestamp: DateTime) -> Self { + Self { status, timestamp } + } + + pub fn new_registered() -> Self { + Self::new(Status::Registered, chrono::Utc::now()) + } + + pub fn new_cancelled() -> Self { + Self::new(Status::Cancelled, chrono::Utc::now()) + } + + pub fn into_status(self) -> Status { + self.status + } +} + +impl From for StatusWithContext { + fn from(status: Status) -> Self { + Self::new(status, chrono::Utc::now()) + } +} + +/// The key to identify a request in the pool +#[derive( + PartialEq, Debug, Clone, Deserialize, Serialize, Eq, PartialOrd, Ord, Hash, RedisValue, +)] +pub enum RequestKey { + SingleProof(SingleProofRequestKey), + Aggregation(AggregationRequestKey), +} + +impl RequestKey { + pub fn proof_type(&self) -> &ProofType { + match self { + RequestKey::SingleProof(key) => &key.proof_type, + RequestKey::Aggregation(key) => &key.proof_type, + } + } +} + +/// The key to identify a request in the pool +#[derive( + PartialEq, Debug, Clone, Deserialize, Serialize, Eq, PartialOrd, Ord, Hash, RedisValue, Getters, +)] +pub struct SingleProofRequestKey { + /// The chain ID of the request + chain_id: ChainId, + /// The block number of the request + block_number: u64, + /// The block hash of the request + block_hash: B256, + /// The proof type of the request + proof_type: ProofType, + /// The prover of the request + prover_address: String, +} + +impl SingleProofRequestKey { + pub fn new( + chain_id: ChainId, + block_number: u64, + block_hash: B256, + proof_type: ProofType, + prover_address: String, + ) -> Self { + Self { + chain_id, + block_number, + block_hash, + proof_type, + prover_address, + } + } +} + +#[derive( + PartialEq, Debug, Clone, Deserialize, Serialize, Eq, PartialOrd, Ord, Hash, RedisValue, Getters, +)] +/// The key to identify an aggregation request in the pool +pub struct AggregationRequestKey { + // TODO add chain_id + proof_type: ProofType, + block_numbers: Vec, +} + +impl AggregationRequestKey { + pub fn new(proof_type: ProofType, block_numbers: Vec) -> Self { + Self { + proof_type, + block_numbers, + } + } +} + +impl From for RequestKey { + fn from(key: SingleProofRequestKey) -> Self { + RequestKey::SingleProof(key) + } +} + +impl From for RequestKey { + fn from(key: AggregationRequestKey) -> Self { + RequestKey::Aggregation(key) + } +} + +#[serde_as] +#[derive(PartialEq, Debug, Clone, Deserialize, Serialize, RedisValue, Getters)] +pub struct SingleProofRequestEntity { + /// The block number for the block to generate a proof for. + block_number: u64, + /// The l1 block number of the l2 block be proposed. + l1_inclusion_block_number: u64, + /// The network to generate the proof for. + network: String, + /// The L1 network to generate the proof for. + l1_network: String, + /// Graffiti. + graffiti: B256, + /// The protocol instance data. + #[serde_as(as = "DisplayFromStr")] + prover: Address, + /// The proof type. + proof_type: ProofType, + /// Blob proof type. + blob_proof_type: BlobProofType, + #[serde(flatten)] + /// Additional prover params. + prover_args: HashMap, +} + +impl SingleProofRequestEntity { + pub fn new( + block_number: u64, + l1_inclusion_block_number: u64, + network: String, + l1_network: String, + graffiti: B256, + prover: Address, + proof_type: ProofType, + blob_proof_type: BlobProofType, + prover_args: HashMap, + ) -> Self { + Self { + block_number, + l1_inclusion_block_number, + network, + l1_network, + graffiti, + prover, + proof_type, + blob_proof_type, + prover_args, + } + } +} + +#[derive(PartialEq, Debug, Clone, Deserialize, Serialize, RedisValue, Getters)] +pub struct AggregationRequestEntity { + /// The block numbers and l1 inclusion block numbers for the blocks to aggregate proofs for. + aggregation_ids: Vec, + /// The block numbers and l1 inclusion block numbers for the blocks to aggregate proofs for. + proofs: Vec, + /// The proof type. + proof_type: ProofType, + #[serde(flatten)] + /// Any additional prover params in JSON format. + prover_args: ProverSpecificOpts, +} + +impl AggregationRequestEntity { + pub fn new( + aggregation_ids: Vec, + proofs: Vec, + proof_type: ProofType, + prover_args: ProverSpecificOpts, + ) -> Self { + Self { + aggregation_ids, + proofs, + proof_type, + prover_args, + } + } +} + +/// The entity of a request +#[derive(PartialEq, Debug, Clone, Deserialize, Serialize, RedisValue)] +pub enum RequestEntity { + SingleProof(SingleProofRequestEntity), + Aggregation(AggregationRequestEntity), +} + +impl From for RequestEntity { + fn from(entity: SingleProofRequestEntity) -> Self { + RequestEntity::SingleProof(entity) + } +} + +impl From for RequestEntity { + fn from(entity: AggregationRequestEntity) -> Self { + RequestEntity::Aggregation(entity) + } +} + +// === impl Display using json_pretty === + +impl_display_using_json_pretty!(RequestKey); +impl_display_using_json_pretty!(SingleProofRequestKey); +impl_display_using_json_pretty!(AggregationRequestKey); +impl_display_using_json_pretty!(RequestEntity); +impl_display_using_json_pretty!(SingleProofRequestEntity); +impl_display_using_json_pretty!(AggregationRequestEntity); + +// === impl Display for Status === + +impl std::fmt::Display for Status { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + Status::Registered => write!(f, "Registered"), + Status::WorkInProgress => write!(f, "WorkInProgress"), + Status::Success { .. } => write!(f, "Success"), + Status::Cancelled => write!(f, "Cancelled"), + Status::Failed { error } => write!(f, "Failed({})", error), + } + } +} + +impl std::fmt::Display for StatusWithContext { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "{}", self.status()) + } +} diff --git a/reqpool/src/utils.rs b/reqpool/src/utils.rs new file mode 100644 index 00000000..ba9d5c3d --- /dev/null +++ b/reqpool/src/utils.rs @@ -0,0 +1,27 @@ +use raiko_lib::{proof_type::ProofType, prover::ProofKey}; + +use crate::{RequestKey, SingleProofRequestKey}; + +/// Returns the proof key corresponding to the request key. +/// +/// During proving, the prover will store the network proof id into pool, which is identified by **proof key**. This +/// function is used to generate a unique proof key corresponding to the request key, so that we can store the +/// proof key into the pool. +/// +/// Note that this is a hack, and it should be removed in the future. +pub fn proof_key_to_hack_request_key(proof_key: ProofKey) -> RequestKey { + let (chain_id, block_number, block_hash, proof_type) = proof_key; + + // HACK: Use a special prover address as a mask, to distinguish from real + // RequestKeys + let hack_prover_address = String::from("0x1231231231231231231231231231231231231231"); + + SingleProofRequestKey::new( + chain_id, + block_number, + block_hash, + ProofType::try_from(proof_type).expect("unsupported proof type, it should not happen at proof_key_to_hack_request_key, please issue a bug report"), + hack_prover_address, + ) + .into() +}