Skip to content
This repository has been archived by the owner on Aug 21, 2024. It is now read-only.

Commit

Permalink
test(concurrency): multithreaded flow test for the scheduler
Browse files Browse the repository at this point in the history
  • Loading branch information
barak-b-starkware committed Jun 19, 2024
1 parent 5129ccd commit 90e7e69
Show file tree
Hide file tree
Showing 6 changed files with 199 additions and 10 deletions.
11 changes: 8 additions & 3 deletions crates/blockifier/src/abi/sierra_types.rs
Original file line number Diff line number Diff line change
Expand Up @@ -36,18 +36,23 @@ pub trait SierraType: Sized {
fn from_memory(vm: &VirtualMachine, ptr: &mut Relocatable) -> SierraTypeResult<Self>;

fn from_storage(
state: &mut dyn StateReader,
state: &dyn StateReader,
contract_address: &ContractAddress,
key: &StorageKey,
) -> SierraTypeResult<Self>;
}

// Utils.

pub fn felt_to_u128(felt: &Felt252) -> Result<u128, SierraTypeError> {
felt.to_u128()
.ok_or_else(|| SierraTypeError::ValueTooLargeForType { val: felt.clone(), ty: "u128" })
}

pub fn stark_felt_to_u128(stark_felt: &StarkFelt) -> Result<u128, SierraTypeError> {
felt_to_u128(&stark_felt_to_felt(*stark_felt))
}

// TODO(barak, 01/10/2023): Move to starknet_api under StorageKey implementation.
pub fn next_storage_key(key: &StorageKey) -> Result<StorageKey, StarknetApiError> {
Ok(StorageKey(PatriciaKey::try_from(StarkFelt::from(
Expand Down Expand Up @@ -78,7 +83,7 @@ impl SierraType for SierraU128 {
}

fn from_storage(
state: &mut dyn StateReader,
state: &dyn StateReader,
contract_address: &ContractAddress,
key: &StorageKey,
) -> SierraTypeResult<Self> {
Expand Down Expand Up @@ -111,7 +116,7 @@ impl SierraType for SierraU256 {
}

fn from_storage(
state: &mut dyn StateReader,
state: &dyn StateReader,
contract_address: &ContractAddress,
key: &StorageKey,
) -> SierraTypeResult<Self> {
Expand Down
173 changes: 173 additions & 0 deletions crates/blockifier/src/concurrency/flow_test.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,173 @@
use std::collections::HashMap;
use std::sync::atomic::Ordering;
use std::sync::Arc;

use rstest::rstest;
use starknet_api::core::{ContractAddress, PatriciaKey};
use starknet_api::hash::{StarkFelt, StarkHash};
use starknet_api::{contract_address, patricia_key, stark_felt};

use crate::abi::sierra_types::{SierraType, SierraU128};
use crate::concurrency::scheduler::{Scheduler, Task, TransactionStatus};
use crate::concurrency::test_utils::{safe_versioned_state_for_testing, DEFAULT_CHUNK_SIZE};
use crate::concurrency::versioned_state::ThreadSafeVersionedState;
use crate::state::cached_state::{CachedState, ContractClassMapping, StateMaps};
use crate::state::state_api::UpdatableState;
use crate::storage_key;
use crate::test_utils::dict_state_reader::DictStateReader;

const CONTRACT_ADDRESS: &str = "0x18031991";
const STORAGE_KEY: u8 = 27;

#[rstest]
fn scheduler_flow_test(
// TODO(barak, 01/07/2024): Add a separate identical test and use the package loom.
#[values(1, 2, 4, 32, 64, 128)] num_threads: u8,
) {
// Tests the Scheduler under a heavy load of validation aborts. To do that, we simulate multiple
// transactions with multiple threads, where every transaction depends on its predecessor. Each
// transaction sequentially advances a counter by reading the previous value and bumping it by
// 1.
let scheduler = Arc::new(Scheduler::new(DEFAULT_CHUNK_SIZE));
let versioned_state =
safe_versioned_state_for_testing(CachedState::from(DictStateReader::default()));
let mut handles = vec![];

for _ in 0..num_threads {
let scheduler = Arc::clone(&scheduler);
let versioned_state = versioned_state.clone();
let handle = std::thread::spawn(move || {
let mut task = Task::NoTask;
loop {
if let Some(mut transaction_committer) = scheduler.try_enter_commit_phase() {
while let Some(tx_index) = transaction_committer.try_commit() {
let mut state_proxy = versioned_state.pin_version(tx_index);
let (reads, writes) =
get_reads_writes_for(Task::ValidationTask(tx_index), &versioned_state);
let reads_valid = state_proxy.validate_reads(&reads);
if !reads_valid {
state_proxy.delete_writes(&writes, &ContractClassMapping::default());
let (_, new_writes) = get_reads_writes_for(
Task::ExecutionTask(tx_index),
&versioned_state,
);
state_proxy.apply_writes(
&new_writes,
&ContractClassMapping::default(),
&HashMap::default(),
);
scheduler.finish_execution_during_commit(tx_index);
}
}
}
task = match task {
Task::ExecutionTask(tx_index) => {
let (_, writes) =
get_reads_writes_for(Task::ExecutionTask(tx_index), &versioned_state);
versioned_state.pin_version(tx_index).apply_writes(
&writes,
&ContractClassMapping::default(),
&HashMap::default(),
);
scheduler.finish_execution(tx_index);
Task::NoTask
}
Task::ValidationTask(tx_index) => {
let state_proxy = versioned_state.pin_version(tx_index);
let (reads, writes) =
get_reads_writes_for(Task::ValidationTask(tx_index), &versioned_state);
let read_set_valid = state_proxy.validate_reads(&reads);
let aborted = !read_set_valid && scheduler.try_validation_abort(tx_index);
if aborted {
state_proxy.delete_writes(&writes, &ContractClassMapping::default());
scheduler.finish_abort(tx_index)
} else {
Task::NoTask
}
}
Task::NoTask => scheduler.next_task(),
Task::Done => break,
}
}
});
handles.push(handle);
}
for handle in handles {
handle.join().unwrap();
}

// The execution index can be strictly greater than chunk_size. This is a side effect of using
// atomic variables instead of locks, which can encapsulate both the check (whether to increment
// the variable or not) and the incrementation in a scope where no other threads can access the
// variable.
assert!(scheduler.execution_index.load(Ordering::Acquire) >= DEFAULT_CHUNK_SIZE);
// There is no guarantee about the validation index because of the use of the commit index.
assert!(*scheduler.commit_index.lock().unwrap() == DEFAULT_CHUNK_SIZE);
assert!(scheduler.get_n_committed_txs() == DEFAULT_CHUNK_SIZE);
assert!(scheduler.done_marker.load(Ordering::Acquire));
for tx_index in 0..DEFAULT_CHUNK_SIZE {
assert_eq!(*scheduler.tx_statuses[tx_index].lock().unwrap(), TransactionStatus::Committed);
let storage_writes = versioned_state.state().get_writes_of_index(tx_index).storage;
assert_eq!(
*storage_writes
.get(&(contract_address!(CONTRACT_ADDRESS), storage_key!(STORAGE_KEY)))
.unwrap(),
stark_felt!(format!("{:x}", tx_index + 1).as_str())
);
}
}

fn get_reads_writes_for(
task: Task,
versioned_state: &ThreadSafeVersionedState<CachedState<DictStateReader>>,
) -> (StateMaps, StateMaps) {
match task {
Task::ExecutionTask(tx_index) => {
let state_proxy = match tx_index {
0 => {
return (
state_maps_with_single_storage_entry(0),
state_maps_with_single_storage_entry(1),
);
}
_ => versioned_state.pin_version(tx_index - 1),
};
let tx_written_value = SierraU128::from_storage(
&state_proxy,
&contract_address!(CONTRACT_ADDRESS),
&storage_key!(STORAGE_KEY),
)
.unwrap()
.as_value();
(
state_maps_with_single_storage_entry(tx_written_value),
state_maps_with_single_storage_entry(tx_written_value + 1),
)
}
Task::ValidationTask(tx_index) => {
let state_proxy = versioned_state.pin_version(tx_index);
let tx_written_value = SierraU128::from_storage(
&state_proxy,
&contract_address!(CONTRACT_ADDRESS),
&storage_key!(STORAGE_KEY),
)
.unwrap()
.as_value();
(
state_maps_with_single_storage_entry(tx_written_value - 1),
state_maps_with_single_storage_entry(tx_written_value),
)
}
_ => panic!("Only execution and validation tasks shold be used here."),
}
}

fn state_maps_with_single_storage_entry(value: u128) -> StateMaps {
StateMaps {
storage: HashMap::from([(
(contract_address!(CONTRACT_ADDRESS), storage_key!(STORAGE_KEY)),
stark_felt!(value),
)]),
..Default::default()
}
}
9 changes: 4 additions & 5 deletions crates/blockifier/src/concurrency/scheduler.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
use std::cmp::min;
use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering};
use std::sync::{Mutex, MutexGuard, TryLockError};

Expand All @@ -9,6 +8,10 @@ use crate::concurrency::TxIndex;
#[path = "scheduler_test.rs"]
pub mod test;

#[cfg(test)]
#[path = "flow_test.rs"]
pub mod flow_test;

pub struct TransactionCommitter<'a> {
scheduler: &'a Scheduler,
commit_index_guard: MutexGuard<'a, usize>,
Expand Down Expand Up @@ -88,10 +91,6 @@ impl Scheduler {
let index_to_validate = self.validation_index.load(Ordering::Acquire);
let index_to_execute = self.execution_index.load(Ordering::Acquire);

if min(index_to_validate, index_to_execute) >= self.chunk_size {
return Task::NoTask;
}

if index_to_validate < index_to_execute {
if let Some(tx_index) = self.next_version_to_validate() {
return Task::ValidationTask(tx_index);
Expand Down
3 changes: 1 addition & 2 deletions crates/blockifier/src/concurrency/scheduler_test.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,10 @@ use pretty_assertions::assert_eq;
use rstest::rstest;

use crate::concurrency::scheduler::{Scheduler, Task, TransactionStatus};
use crate::concurrency::test_utils::DEFAULT_CHUNK_SIZE;
use crate::concurrency::TxIndex;
use crate::default_scheduler;

const DEFAULT_CHUNK_SIZE: usize = 100;

#[rstest]
fn test_new(#[values(0, 1, 32)] chunk_size: usize) {
let scheduler = Scheduler::new(chunk_size);
Expand Down
8 changes: 8 additions & 0 deletions crates/blockifier/src/concurrency/test_utils.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,10 @@ use crate::test_utils::dict_state_reader::DictStateReader;
use crate::transaction::account_transaction::AccountTransaction;
use crate::transaction::transactions::ExecutableTransaction;

// Public Consts.

pub const DEFAULT_CHUNK_SIZE: usize = 64;

// Fixtures.

#[fixture]
Expand Down Expand Up @@ -54,13 +58,17 @@ macro_rules! default_scheduler {
};
}

// Concurrency constructors.

// TODO(meshi, 01/06/2024): Consider making this a macro.
pub fn safe_versioned_state_for_testing(
block_state: CachedState<DictStateReader>,
) -> ThreadSafeVersionedState<CachedState<DictStateReader>> {
ThreadSafeVersionedState::new(VersionedState::new(block_state))
}

// Utils.

// Note: this function does not mutate the state.
pub fn create_fee_transfer_call_info<S: StateReader>(
state: &mut CachedState<S>,
Expand Down
5 changes: 5 additions & 0 deletions crates/blockifier/src/concurrency/versioned_state.rs
Original file line number Diff line number Diff line change
Expand Up @@ -241,6 +241,11 @@ impl<S: StateReader> ThreadSafeVersionedState<S> {
.into_inner()
.expect("No other mutex should hold the versioned state while calling this method.")
}

#[cfg(test)]
pub fn state(&self) -> LockedVersionedState<'_, S> {
self.0.lock().expect("Failed to acquire state lock.")
}
}

impl<S: StateReader> Clone for ThreadSafeVersionedState<S> {
Expand Down

0 comments on commit 90e7e69

Please sign in to comment.