From 90e7e69b130cd2f5d2cfbcb7e0971855ed8c4238 Mon Sep 17 00:00:00 2001 From: Barak Date: Sat, 27 Apr 2024 09:30:20 +0300 Subject: [PATCH] test(concurrency): multithreaded flow test for the scheduler --- crates/blockifier/src/abi/sierra_types.rs | 11 +- .../blockifier/src/concurrency/flow_test.rs | 173 ++++++++++++++++++ .../blockifier/src/concurrency/scheduler.rs | 9 +- .../src/concurrency/scheduler_test.rs | 3 +- .../blockifier/src/concurrency/test_utils.rs | 8 + .../src/concurrency/versioned_state.rs | 5 + 6 files changed, 199 insertions(+), 10 deletions(-) create mode 100644 crates/blockifier/src/concurrency/flow_test.rs diff --git a/crates/blockifier/src/abi/sierra_types.rs b/crates/blockifier/src/abi/sierra_types.rs index 74893a1c55..8e48744a16 100644 --- a/crates/blockifier/src/abi/sierra_types.rs +++ b/crates/blockifier/src/abi/sierra_types.rs @@ -36,18 +36,23 @@ pub trait SierraType: Sized { fn from_memory(vm: &VirtualMachine, ptr: &mut Relocatable) -> SierraTypeResult; fn from_storage( - state: &mut dyn StateReader, + state: &dyn StateReader, contract_address: &ContractAddress, key: &StorageKey, ) -> SierraTypeResult; } // Utils. + pub fn felt_to_u128(felt: &Felt252) -> Result { felt.to_u128() .ok_or_else(|| SierraTypeError::ValueTooLargeForType { val: felt.clone(), ty: "u128" }) } +pub fn stark_felt_to_u128(stark_felt: &StarkFelt) -> Result { + felt_to_u128(&stark_felt_to_felt(*stark_felt)) +} + // TODO(barak, 01/10/2023): Move to starknet_api under StorageKey implementation. pub fn next_storage_key(key: &StorageKey) -> Result { Ok(StorageKey(PatriciaKey::try_from(StarkFelt::from( @@ -78,7 +83,7 @@ impl SierraType for SierraU128 { } fn from_storage( - state: &mut dyn StateReader, + state: &dyn StateReader, contract_address: &ContractAddress, key: &StorageKey, ) -> SierraTypeResult { @@ -111,7 +116,7 @@ impl SierraType for SierraU256 { } fn from_storage( - state: &mut dyn StateReader, + state: &dyn StateReader, contract_address: &ContractAddress, key: &StorageKey, ) -> SierraTypeResult { diff --git a/crates/blockifier/src/concurrency/flow_test.rs b/crates/blockifier/src/concurrency/flow_test.rs new file mode 100644 index 0000000000..509cd02981 --- /dev/null +++ b/crates/blockifier/src/concurrency/flow_test.rs @@ -0,0 +1,173 @@ +use std::collections::HashMap; +use std::sync::atomic::Ordering; +use std::sync::Arc; + +use rstest::rstest; +use starknet_api::core::{ContractAddress, PatriciaKey}; +use starknet_api::hash::{StarkFelt, StarkHash}; +use starknet_api::{contract_address, patricia_key, stark_felt}; + +use crate::abi::sierra_types::{SierraType, SierraU128}; +use crate::concurrency::scheduler::{Scheduler, Task, TransactionStatus}; +use crate::concurrency::test_utils::{safe_versioned_state_for_testing, DEFAULT_CHUNK_SIZE}; +use crate::concurrency::versioned_state::ThreadSafeVersionedState; +use crate::state::cached_state::{CachedState, ContractClassMapping, StateMaps}; +use crate::state::state_api::UpdatableState; +use crate::storage_key; +use crate::test_utils::dict_state_reader::DictStateReader; + +const CONTRACT_ADDRESS: &str = "0x18031991"; +const STORAGE_KEY: u8 = 27; + +#[rstest] +fn scheduler_flow_test( + // TODO(barak, 01/07/2024): Add a separate identical test and use the package loom. + #[values(1, 2, 4, 32, 64, 128)] num_threads: u8, +) { + // Tests the Scheduler under a heavy load of validation aborts. To do that, we simulate multiple + // transactions with multiple threads, where every transaction depends on its predecessor. Each + // transaction sequentially advances a counter by reading the previous value and bumping it by + // 1. + let scheduler = Arc::new(Scheduler::new(DEFAULT_CHUNK_SIZE)); + let versioned_state = + safe_versioned_state_for_testing(CachedState::from(DictStateReader::default())); + let mut handles = vec![]; + + for _ in 0..num_threads { + let scheduler = Arc::clone(&scheduler); + let versioned_state = versioned_state.clone(); + let handle = std::thread::spawn(move || { + let mut task = Task::NoTask; + loop { + if let Some(mut transaction_committer) = scheduler.try_enter_commit_phase() { + while let Some(tx_index) = transaction_committer.try_commit() { + let mut state_proxy = versioned_state.pin_version(tx_index); + let (reads, writes) = + get_reads_writes_for(Task::ValidationTask(tx_index), &versioned_state); + let reads_valid = state_proxy.validate_reads(&reads); + if !reads_valid { + state_proxy.delete_writes(&writes, &ContractClassMapping::default()); + let (_, new_writes) = get_reads_writes_for( + Task::ExecutionTask(tx_index), + &versioned_state, + ); + state_proxy.apply_writes( + &new_writes, + &ContractClassMapping::default(), + &HashMap::default(), + ); + scheduler.finish_execution_during_commit(tx_index); + } + } + } + task = match task { + Task::ExecutionTask(tx_index) => { + let (_, writes) = + get_reads_writes_for(Task::ExecutionTask(tx_index), &versioned_state); + versioned_state.pin_version(tx_index).apply_writes( + &writes, + &ContractClassMapping::default(), + &HashMap::default(), + ); + scheduler.finish_execution(tx_index); + Task::NoTask + } + Task::ValidationTask(tx_index) => { + let state_proxy = versioned_state.pin_version(tx_index); + let (reads, writes) = + get_reads_writes_for(Task::ValidationTask(tx_index), &versioned_state); + let read_set_valid = state_proxy.validate_reads(&reads); + let aborted = !read_set_valid && scheduler.try_validation_abort(tx_index); + if aborted { + state_proxy.delete_writes(&writes, &ContractClassMapping::default()); + scheduler.finish_abort(tx_index) + } else { + Task::NoTask + } + } + Task::NoTask => scheduler.next_task(), + Task::Done => break, + } + } + }); + handles.push(handle); + } + for handle in handles { + handle.join().unwrap(); + } + + // The execution index can be strictly greater than chunk_size. This is a side effect of using + // atomic variables instead of locks, which can encapsulate both the check (whether to increment + // the variable or not) and the incrementation in a scope where no other threads can access the + // variable. + assert!(scheduler.execution_index.load(Ordering::Acquire) >= DEFAULT_CHUNK_SIZE); + // There is no guarantee about the validation index because of the use of the commit index. + assert!(*scheduler.commit_index.lock().unwrap() == DEFAULT_CHUNK_SIZE); + assert!(scheduler.get_n_committed_txs() == DEFAULT_CHUNK_SIZE); + assert!(scheduler.done_marker.load(Ordering::Acquire)); + for tx_index in 0..DEFAULT_CHUNK_SIZE { + assert_eq!(*scheduler.tx_statuses[tx_index].lock().unwrap(), TransactionStatus::Committed); + let storage_writes = versioned_state.state().get_writes_of_index(tx_index).storage; + assert_eq!( + *storage_writes + .get(&(contract_address!(CONTRACT_ADDRESS), storage_key!(STORAGE_KEY))) + .unwrap(), + stark_felt!(format!("{:x}", tx_index + 1).as_str()) + ); + } +} + +fn get_reads_writes_for( + task: Task, + versioned_state: &ThreadSafeVersionedState>, +) -> (StateMaps, StateMaps) { + match task { + Task::ExecutionTask(tx_index) => { + let state_proxy = match tx_index { + 0 => { + return ( + state_maps_with_single_storage_entry(0), + state_maps_with_single_storage_entry(1), + ); + } + _ => versioned_state.pin_version(tx_index - 1), + }; + let tx_written_value = SierraU128::from_storage( + &state_proxy, + &contract_address!(CONTRACT_ADDRESS), + &storage_key!(STORAGE_KEY), + ) + .unwrap() + .as_value(); + ( + state_maps_with_single_storage_entry(tx_written_value), + state_maps_with_single_storage_entry(tx_written_value + 1), + ) + } + Task::ValidationTask(tx_index) => { + let state_proxy = versioned_state.pin_version(tx_index); + let tx_written_value = SierraU128::from_storage( + &state_proxy, + &contract_address!(CONTRACT_ADDRESS), + &storage_key!(STORAGE_KEY), + ) + .unwrap() + .as_value(); + ( + state_maps_with_single_storage_entry(tx_written_value - 1), + state_maps_with_single_storage_entry(tx_written_value), + ) + } + _ => panic!("Only execution and validation tasks shold be used here."), + } +} + +fn state_maps_with_single_storage_entry(value: u128) -> StateMaps { + StateMaps { + storage: HashMap::from([( + (contract_address!(CONTRACT_ADDRESS), storage_key!(STORAGE_KEY)), + stark_felt!(value), + )]), + ..Default::default() + } +} diff --git a/crates/blockifier/src/concurrency/scheduler.rs b/crates/blockifier/src/concurrency/scheduler.rs index 1aa01dc832..0e37a70e23 100644 --- a/crates/blockifier/src/concurrency/scheduler.rs +++ b/crates/blockifier/src/concurrency/scheduler.rs @@ -1,4 +1,3 @@ -use std::cmp::min; use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering}; use std::sync::{Mutex, MutexGuard, TryLockError}; @@ -9,6 +8,10 @@ use crate::concurrency::TxIndex; #[path = "scheduler_test.rs"] pub mod test; +#[cfg(test)] +#[path = "flow_test.rs"] +pub mod flow_test; + pub struct TransactionCommitter<'a> { scheduler: &'a Scheduler, commit_index_guard: MutexGuard<'a, usize>, @@ -88,10 +91,6 @@ impl Scheduler { let index_to_validate = self.validation_index.load(Ordering::Acquire); let index_to_execute = self.execution_index.load(Ordering::Acquire); - if min(index_to_validate, index_to_execute) >= self.chunk_size { - return Task::NoTask; - } - if index_to_validate < index_to_execute { if let Some(tx_index) = self.next_version_to_validate() { return Task::ValidationTask(tx_index); diff --git a/crates/blockifier/src/concurrency/scheduler_test.rs b/crates/blockifier/src/concurrency/scheduler_test.rs index 92fcdecd4e..ff62d79bda 100644 --- a/crates/blockifier/src/concurrency/scheduler_test.rs +++ b/crates/blockifier/src/concurrency/scheduler_test.rs @@ -6,11 +6,10 @@ use pretty_assertions::assert_eq; use rstest::rstest; use crate::concurrency::scheduler::{Scheduler, Task, TransactionStatus}; +use crate::concurrency::test_utils::DEFAULT_CHUNK_SIZE; use crate::concurrency::TxIndex; use crate::default_scheduler; -const DEFAULT_CHUNK_SIZE: usize = 100; - #[rstest] fn test_new(#[values(0, 1, 32)] chunk_size: usize) { let scheduler = Scheduler::new(chunk_size); diff --git a/crates/blockifier/src/concurrency/test_utils.rs b/crates/blockifier/src/concurrency/test_utils.rs index c59e406c43..f8834334d6 100644 --- a/crates/blockifier/src/concurrency/test_utils.rs +++ b/crates/blockifier/src/concurrency/test_utils.rs @@ -12,6 +12,10 @@ use crate::test_utils::dict_state_reader::DictStateReader; use crate::transaction::account_transaction::AccountTransaction; use crate::transaction::transactions::ExecutableTransaction; +// Public Consts. + +pub const DEFAULT_CHUNK_SIZE: usize = 64; + // Fixtures. #[fixture] @@ -54,6 +58,8 @@ macro_rules! default_scheduler { }; } +// Concurrency constructors. + // TODO(meshi, 01/06/2024): Consider making this a macro. pub fn safe_versioned_state_for_testing( block_state: CachedState, @@ -61,6 +67,8 @@ pub fn safe_versioned_state_for_testing( ThreadSafeVersionedState::new(VersionedState::new(block_state)) } +// Utils. + // Note: this function does not mutate the state. pub fn create_fee_transfer_call_info( state: &mut CachedState, diff --git a/crates/blockifier/src/concurrency/versioned_state.rs b/crates/blockifier/src/concurrency/versioned_state.rs index 4b4be73ed0..796235be25 100644 --- a/crates/blockifier/src/concurrency/versioned_state.rs +++ b/crates/blockifier/src/concurrency/versioned_state.rs @@ -241,6 +241,11 @@ impl ThreadSafeVersionedState { .into_inner() .expect("No other mutex should hold the versioned state while calling this method.") } + + #[cfg(test)] + pub fn state(&self) -> LockedVersionedState<'_, S> { + self.0.lock().expect("Failed to acquire state lock.") + } } impl Clone for ThreadSafeVersionedState {