Skip to content
This repository has been archived by the owner on Aug 21, 2024. It is now read-only.

test(concurrency): multithreaded flow test for the scheduler #1830

Merged
merged 1 commit into from
Jun 24, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 8 additions & 3 deletions crates/blockifier/src/abi/sierra_types.rs
Original file line number Diff line number Diff line change
Expand Up @@ -36,18 +36,23 @@ pub trait SierraType: Sized {
fn from_memory(vm: &VirtualMachine, ptr: &mut Relocatable) -> SierraTypeResult<Self>;

fn from_storage(
state: &mut dyn StateReader,
state: &dyn StateReader,
contract_address: &ContractAddress,
key: &StorageKey,
) -> SierraTypeResult<Self>;
}

// Utils.

pub fn felt_to_u128(felt: &Felt252) -> Result<u128, SierraTypeError> {
felt.to_u128()
.ok_or_else(|| SierraTypeError::ValueTooLargeForType { val: felt.clone(), ty: "u128" })
}

pub fn stark_felt_to_u128(stark_felt: &StarkFelt) -> Result<u128, SierraTypeError> {
felt_to_u128(&stark_felt_to_felt(*stark_felt))
}

// TODO(barak, 01/10/2023): Move to starknet_api under StorageKey implementation.
pub fn next_storage_key(key: &StorageKey) -> Result<StorageKey, StarknetApiError> {
Ok(StorageKey(PatriciaKey::try_from(StarkFelt::from(
Expand Down Expand Up @@ -78,7 +83,7 @@ impl SierraType for SierraU128 {
}

fn from_storage(
state: &mut dyn StateReader,
state: &dyn StateReader,
contract_address: &ContractAddress,
key: &StorageKey,
) -> SierraTypeResult<Self> {
Expand Down Expand Up @@ -111,7 +116,7 @@ impl SierraType for SierraU256 {
}

fn from_storage(
state: &mut dyn StateReader,
state: &dyn StateReader,
contract_address: &ContractAddress,
key: &StorageKey,
) -> SierraTypeResult<Self> {
Expand Down
174 changes: 174 additions & 0 deletions crates/blockifier/src/concurrency/flow_test.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,174 @@
use std::collections::HashMap;
use std::sync::atomic::Ordering;
use std::sync::Arc;

use rstest::rstest;
use starknet_api::core::{ContractAddress, PatriciaKey};
use starknet_api::hash::{StarkFelt, StarkHash};
use starknet_api::{contract_address, patricia_key, stark_felt};

use crate::abi::sierra_types::{SierraType, SierraU128};
use crate::concurrency::scheduler::{Scheduler, Task, TransactionStatus};
use crate::concurrency::test_utils::{safe_versioned_state_for_testing, DEFAULT_CHUNK_SIZE};
use crate::concurrency::versioned_state::ThreadSafeVersionedState;
use crate::state::cached_state::{CachedState, ContractClassMapping, StateMaps};
use crate::state::state_api::UpdatableState;
use crate::storage_key;
use crate::test_utils::dict_state_reader::DictStateReader;

const CONTRACT_ADDRESS: &str = "0x18031991";
const STORAGE_KEY: u8 = 27;

#[rstest]
fn scheduler_flow_test(
// TODO(barak, 01/07/2024): Add a separate identical test and use the package loom.
#[values(1, 2, 4, 32, 64, 128)] num_threads: u8,
) {
// Tests the Scheduler under a heavy load of validation aborts. To do that, we simulate multiple
// transactions with multiple threads, where every transaction depends on its predecessor. Each
// transaction sequentially advances a counter by reading the previous value and bumping it by
// 1.
let scheduler = Arc::new(Scheduler::new(DEFAULT_CHUNK_SIZE));
let versioned_state =
safe_versioned_state_for_testing(CachedState::from(DictStateReader::default()));
let mut handles = vec![];

for _ in 0..num_threads {
let scheduler = Arc::clone(&scheduler);
let versioned_state = versioned_state.clone();
let handle = std::thread::spawn(move || {
let mut task = Task::NoTask;
loop {
if let Some(mut transaction_committer) = scheduler.try_enter_commit_phase() {
while let Some(tx_index) = transaction_committer.try_commit() {
let mut state_proxy = versioned_state.pin_version(tx_index);
let (reads, writes) =
get_reads_writes_for(Task::ValidationTask(tx_index), &versioned_state);
let reads_valid = state_proxy.validate_reads(&reads);
if !reads_valid {
state_proxy.delete_writes(&writes, &ContractClassMapping::default());
let (_, new_writes) = get_reads_writes_for(
Task::ExecutionTask(tx_index),
&versioned_state,
);
state_proxy.apply_writes(
&new_writes,
&ContractClassMapping::default(),
&HashMap::default(),
);
scheduler.finish_execution_during_commit(tx_index);
}
}
}
task = match task {
Task::ExecutionTask(tx_index) => {
let (_, writes) =
get_reads_writes_for(Task::ExecutionTask(tx_index), &versioned_state);
versioned_state.pin_version(tx_index).apply_writes(
&writes,
&ContractClassMapping::default(),
&HashMap::default(),
);
scheduler.finish_execution(tx_index);
Task::NoTask
}
Task::ValidationTask(tx_index) => {
let state_proxy = versioned_state.pin_version(tx_index);
let (reads, writes) =
get_reads_writes_for(Task::ValidationTask(tx_index), &versioned_state);
let read_set_valid = state_proxy.validate_reads(&reads);
let aborted = !read_set_valid && scheduler.try_validation_abort(tx_index);
if aborted {
state_proxy.delete_writes(&writes, &ContractClassMapping::default());
scheduler.finish_abort(tx_index)
} else {
Task::NoTask
}
}
Task::NoTask => scheduler.next_task(),
Task::Done => break,
}
}
});
handles.push(handle);
}
for handle in handles {
handle.join().unwrap();
}

// The execution index can be strictly greater than chunk_size. This is a side effect of using
// atomic variables instead of locks, which can encapsulate both the check (whether to increment
// the variable or not) and the incrementation in a scope where no other threads can access the
// variable.
assert!(scheduler.execution_index.load(Ordering::Acquire) >= DEFAULT_CHUNK_SIZE);
// There is no guarantee about the validation index because of the use of the commit index.
assert!(*scheduler.commit_index.lock().unwrap() == DEFAULT_CHUNK_SIZE);
assert!(scheduler.get_n_committed_txs() == DEFAULT_CHUNK_SIZE);
assert!(scheduler.done_marker.load(Ordering::Acquire));
let inner_versioned_state = versioned_state.into_inner_state();
for tx_index in 0..DEFAULT_CHUNK_SIZE {
assert_eq!(*scheduler.tx_statuses[tx_index].lock().unwrap(), TransactionStatus::Committed);
let storage_writes = inner_versioned_state.get_writes_of_index(tx_index).storage;
assert_eq!(
*storage_writes
.get(&(contract_address!(CONTRACT_ADDRESS), storage_key!(STORAGE_KEY)))
.unwrap(),
stark_felt!(format!("{:x}", tx_index + 1).as_str())
);
}
}

fn get_reads_writes_for(
task: Task,
versioned_state: &ThreadSafeVersionedState<CachedState<DictStateReader>>,
) -> (StateMaps, StateMaps) {
match task {
Task::ExecutionTask(tx_index) => {
let state_proxy = match tx_index {
0 => {
return (
state_maps_with_single_storage_entry(0),
state_maps_with_single_storage_entry(1),
);
}
_ => versioned_state.pin_version(tx_index - 1),
};
let tx_written_value = SierraU128::from_storage(
&state_proxy,
&contract_address!(CONTRACT_ADDRESS),
&storage_key!(STORAGE_KEY),
)
.unwrap()
.as_value();
(
state_maps_with_single_storage_entry(tx_written_value),
state_maps_with_single_storage_entry(tx_written_value + 1),
)
}
Task::ValidationTask(tx_index) => {
let state_proxy = versioned_state.pin_version(tx_index);
let tx_written_value = SierraU128::from_storage(
&state_proxy,
&contract_address!(CONTRACT_ADDRESS),
&storage_key!(STORAGE_KEY),
)
.unwrap()
.as_value();
(
state_maps_with_single_storage_entry(tx_written_value - 1),
state_maps_with_single_storage_entry(tx_written_value),
)
}
_ => panic!("Only execution and validation tasks shold be used here."),
}
}

fn state_maps_with_single_storage_entry(value: u128) -> StateMaps {
StateMaps {
storage: HashMap::from([(
(contract_address!(CONTRACT_ADDRESS), storage_key!(STORAGE_KEY)),
stark_felt!(value),
)]),
..Default::default()
}
}
9 changes: 4 additions & 5 deletions crates/blockifier/src/concurrency/scheduler.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
use std::cmp::min;
use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering};
use std::sync::{Mutex, MutexGuard, TryLockError};

Expand All @@ -9,6 +8,10 @@ use crate::concurrency::TxIndex;
#[path = "scheduler_test.rs"]
pub mod test;

#[cfg(test)]
#[path = "flow_test.rs"]
pub mod flow_test;

pub struct TransactionCommitter<'a> {
scheduler: &'a Scheduler,
commit_index_guard: MutexGuard<'a, usize>,
Expand Down Expand Up @@ -88,10 +91,6 @@ impl Scheduler {
let index_to_validate = self.validation_index.load(Ordering::Acquire);
let index_to_execute = self.execution_index.load(Ordering::Acquire);

if min(index_to_validate, index_to_execute) >= self.chunk_size {
return Task::NoTask;
}

if index_to_validate < index_to_execute {
if let Some(tx_index) = self.next_version_to_validate() {
return Task::ValidationTask(tx_index);
Expand Down
3 changes: 1 addition & 2 deletions crates/blockifier/src/concurrency/scheduler_test.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,10 @@ use pretty_assertions::assert_eq;
use rstest::rstest;

use crate::concurrency::scheduler::{Scheduler, Task, TransactionStatus};
use crate::concurrency::test_utils::DEFAULT_CHUNK_SIZE;
use crate::concurrency::TxIndex;
use crate::default_scheduler;

const DEFAULT_CHUNK_SIZE: usize = 100;

#[rstest]
fn test_new(#[values(0, 1, 32)] chunk_size: usize) {
let scheduler = Scheduler::new(chunk_size);
Expand Down
8 changes: 8 additions & 0 deletions crates/blockifier/src/concurrency/test_utils.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,10 @@ use crate::test_utils::dict_state_reader::DictStateReader;
use crate::transaction::account_transaction::AccountTransaction;
use crate::transaction::transactions::ExecutableTransaction;

// Public Consts.

pub const DEFAULT_CHUNK_SIZE: usize = 64;

// Fixtures.

#[fixture]
Expand Down Expand Up @@ -54,13 +58,17 @@ macro_rules! default_scheduler {
};
}

// Concurrency constructors.

// TODO(meshi, 01/06/2024): Consider making this a macro.
pub fn safe_versioned_state_for_testing(
block_state: CachedState<DictStateReader>,
) -> ThreadSafeVersionedState<CachedState<DictStateReader>> {
ThreadSafeVersionedState::new(VersionedState::new(block_state))
}

// Utils.

// Note: this function does not mutate the state.
pub fn create_fee_transfer_call_info<S: StateReader>(
state: &mut CachedState<S>,
Expand Down
Loading