From a428c02701d4c2bcbca5d879642d05da8a03befd Mon Sep 17 00:00:00 2001 From: Barak Date: Sat, 27 Apr 2024 09:30:20 +0300 Subject: [PATCH] test(concurrency): multithreaded test for decrease_validation_index --- .../blockifier/src/concurrency/flow_test.rs | 102 ++++++++++++++++++ .../blockifier/src/concurrency/scheduler.rs | 8 +- .../src/concurrency/versioned_state_proxy.rs | 16 +++ .../src/concurrency/versioned_storage.rs | 5 + 4 files changed, 127 insertions(+), 4 deletions(-) create mode 100644 crates/blockifier/src/concurrency/flow_test.rs diff --git a/crates/blockifier/src/concurrency/flow_test.rs b/crates/blockifier/src/concurrency/flow_test.rs new file mode 100644 index 0000000000..d4de23391d --- /dev/null +++ b/crates/blockifier/src/concurrency/flow_test.rs @@ -0,0 +1,102 @@ +use std::collections::HashMap; +use std::sync::Arc; + +use rstest::rstest; +use starknet_api::core::{ContractAddress, PatriciaKey}; +use starknet_api::hash::{StarkFelt, StarkHash}; +use starknet_api::{contract_address, patricia_key, stark_felt}; + +use crate::concurrency::scheduler::{Scheduler, Task}; +use crate::concurrency::test_utils::safe_versioned_state_for_testing; +use crate::state::cached_state::{ContractClassMapping, StateMaps}; +use crate::state::state_api::StateReader; +use crate::storage_key; +use crate::test_utils::dict_state_reader::DictStateReader; + +const DEFAULT_CHUNK_SIZE: usize = 64; + +#[rstest] +fn aaaflow_test() { + // Simulate DEFAULT_CHUNK_SIZE txs. Each reads (CONTRACT_ADDRESS, STORAGE_KEY) and writes its tx + // index inside the chunk to the same storage cell. + let scheduler = Arc::new(Scheduler::new(4)); + let versioned_state = safe_versioned_state_for_testing(DictStateReader { + storage_view: HashMap::from([( + (contract_address!("0x4"), storage_key!(27_u8)), + stark_felt!(128_u8), + )]), + ..Default::default() + }); + let mut handles = vec![]; + + let num_threads = 2_u8.pow(2); + for _ in 0..num_threads { + let scheduler = Arc::clone(&scheduler); + let versioned_state = versioned_state.clone(); + let handle = std::thread::spawn(move || { + let mut task = Task::NoTask; + while !scheduler.done() { + match task { + Task::ExecutionTask(tx_index) => { + let storage_value: u8 = tx_index.try_into().unwrap(); + let versioned_state_proxy = versioned_state.pin_version(tx_index); + // Simulate read to modify the cached_initial_values. + versioned_state_proxy + .get_storage_at(contract_address!("0x4"), storage_key!(27_u8)) + .unwrap(); + let write_set = StateMaps { + storage: HashMap::from([( + (contract_address!("0x4"), storage_key!(27_u8)), + stark_felt!(storage_value), + )]), + ..Default::default() + }; + versioned_state_proxy + .apply_writes(&write_set, &ContractClassMapping::default()); + scheduler.finish_execution(tx_index); + task = Task::NoTask; + } + Task::ValidationTask(tx_index) => { + let versioned_state_proxy = versioned_state.pin_version(tx_index); + let current_cell_value = versioned_state_proxy + .get_storage_at(contract_address!("0x4"), storage_key!(27_u8)) + .unwrap(); + let read_set = StateMaps { + storage: HashMap::from([( + (contract_address!("0x4"), storage_key!(27_u8)), + stark_felt!(current_cell_value), + )]), + ..Default::default() + }; + let read_set_valid = versioned_state_proxy.validate_read_set(&read_set); + let aborted = !read_set_valid && scheduler.try_validation_abort(tx_index); + if aborted { + versioned_state_proxy.delete_writes(tx_index) + } + task = scheduler.finish_validation(tx_index, aborted); + } + Task::NoTask => { + task = scheduler.next_task(); + } + Task::Done => (), + } + } + }); + handles.push(handle); + } + for handle in handles { + handle.join().unwrap(); + } + let storage_writes = versioned_state.state().get_writes(DEFAULT_CHUNK_SIZE).storage; + let storage_ini_vals = versioned_state.state().get_cached_initial_values().storage; + assert_eq!( + *storage_writes.get(&(contract_address!("0x4"), storage_key!(27_u8))).unwrap(), + stark_felt!(3_u8) + ); + assert_eq!( + *storage_ini_vals.get(&(contract_address!("0x4"), storage_key!(27_u8))).unwrap(), + stark_felt!(128_u8) + ); + dbg!(storage_writes.get(&(contract_address!("0x4"), storage_key!(27_u8))).unwrap()); + dbg!(storage_ini_vals.get(&(contract_address!("0x4"), storage_key!(27_u8))).unwrap()); +} diff --git a/crates/blockifier/src/concurrency/scheduler.rs b/crates/blockifier/src/concurrency/scheduler.rs index d8f6a2a485..565c3cece8 100644 --- a/crates/blockifier/src/concurrency/scheduler.rs +++ b/crates/blockifier/src/concurrency/scheduler.rs @@ -8,6 +8,10 @@ use crate::concurrency::TxIndex; #[path = "scheduler_test.rs"] pub mod test; +#[cfg(test)] +#[path = "flow_test.rs"] +pub mod flow_test; + // TODO(Avi, 01/04/2024): Remove dead_code attribute. #[allow(dead_code)] #[derive(Debug, Default)] @@ -57,10 +61,6 @@ impl Scheduler { let index_to_validate = self.validation_index.load(Ordering::Acquire); let index_to_execute = self.execution_index.load(Ordering::Acquire); - if min(index_to_validate, index_to_execute) >= self.chunk_size { - return Task::NoTask; - } - if index_to_validate < index_to_execute { if let Some(tx_index) = self.next_version_to_validate() { return Task::ValidationTask(tx_index); diff --git a/crates/blockifier/src/concurrency/versioned_state_proxy.rs b/crates/blockifier/src/concurrency/versioned_state_proxy.rs index 105ec77d7a..76697a1518 100644 --- a/crates/blockifier/src/concurrency/versioned_state_proxy.rs +++ b/crates/blockifier/src/concurrency/versioned_state_proxy.rs @@ -52,6 +52,17 @@ impl VersionedState { } } + #[cfg(test)] + pub fn get_cached_initial_values(&self) -> StateMaps { + StateMaps { + storage: self.storage.get_cached_initial_values(), + nonces: self.nonces.get_cached_initial_values(), + class_hashes: self.class_hashes.get_cached_initial_values(), + compiled_class_hashes: self.compiled_class_hashes.get_cached_initial_values(), + declared_contracts: HashMap::new(), + } + } + pub fn commit(&mut self, from_index: TxIndex, parent_state: &mut CachedState) where T: StateReader, @@ -160,6 +171,11 @@ impl ThreadSafeVersionedState { pub fn pin_version(&self, tx_index: TxIndex) -> VersionedStateProxy { VersionedStateProxy { tx_index, state: self.0.clone() } } + + #[cfg(test)] + pub fn state(&self) -> LockedVersionedState<'_, S> { + self.0.lock().expect("Failed to acquire state lock.") + } } impl Clone for ThreadSafeVersionedState { diff --git a/crates/blockifier/src/concurrency/versioned_storage.rs b/crates/blockifier/src/concurrency/versioned_storage.rs index 34a35ad3a6..f5a39c4a11 100644 --- a/crates/blockifier/src/concurrency/versioned_storage.rs +++ b/crates/blockifier/src/concurrency/versioned_storage.rs @@ -73,4 +73,9 @@ where } writes } + + #[cfg(test)] + pub fn get_cached_initial_values(&self) -> HashMap { + self.cached_initial_values.clone() + } }