Skip to content
This repository has been archived by the owner on Aug 21, 2024. It is now read-only.

Commit

Permalink
test(concurrency): multithreaded test for decrease_validation_index
Browse files Browse the repository at this point in the history
  • Loading branch information
barak-b-starkware committed May 6, 2024
1 parent 9cf97eb commit 03d3956
Show file tree
Hide file tree
Showing 4 changed files with 147 additions and 3 deletions.
110 changes: 110 additions & 0 deletions crates/blockifier/src/concurrency/flow_test.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,110 @@
use std::collections::HashMap;
use std::sync::atomic::Ordering;
use std::sync::Arc;

use rstest::rstest;
use starknet_api::{contract_address, stark_felt};
use starknet_api::patricia_key;
use starknet_api::core::ContractAddress;
use starknet_api::hash::StarkFelt;
use starknet_api::core::PatriciaKey;
use starknet_api::hash::StarkHash;

use crate::test_utils::dict_state_reader::DictStateReader;
use crate::concurrency::scheduler::{Scheduler, Task};
use crate::concurrency::test_utils::safe_versioned_state_for_testing;
use crate::state::cached_state::{ContractClassMapping, StateMaps};
use crate::state::state_api::StateReader;
use std::time::Duration;
// use crate::concurrency::TxIndex;
use crate::{default_scheduler, nonce, storage_key};

const DEFAULT_CHUNK_SIZE: usize = 100;


#[rstest]
fn aaaflow_test() {
let scheduler = Arc::new(Scheduler::new(2_u8.pow(4).into()));
let versioned_state = safe_versioned_state_for_testing(DictStateReader::default());
let num_threads = 2_u8.pow(5);
let mut handles = vec![];
for i in 0..num_threads {
let scheduler = Arc::clone(&scheduler);
let versioned_state = versioned_state.clone();
let handle = std::thread::spawn(move || {
let mut task = Task::NoTask;
while !scheduler.done() {
match task {
Task::ExecutionTask(tx_index) => {
let versioned_state_proxy = versioned_state.pin_version(tx_index);
let current_nonce: u64 = usize::try_from(versioned_state_proxy.get_nonce_at(contract_address!("0x1")).unwrap().0).unwrap().try_into().unwrap();
let write_set = StateMaps {
storage: HashMap::from([((contract_address!("0x2"), storage_key!(27_u8)), stark_felt!(i))]),
nonces: HashMap::from([(contract_address!("0x1"), nonce!(current_nonce + 1))]),
..Default::default()
};
versioned_state_proxy.apply_writes(&write_set, &ContractClassMapping::default());
scheduler.finish_execution(tx_index);
task = Task::NoTask;
}
Task::ValidationTask(tx_index) => {
let versioned_state_proxy = versioned_state.pin_version(tx_index);
let current_nonce: u64 = usize::try_from(versioned_state_proxy.get_nonce_at(contract_address!("0x1")).unwrap().0).unwrap().try_into().unwrap();
let read_set = StateMaps {
nonces: HashMap::from([(contract_address!("0x1"), nonce!(current_nonce))]),
..Default::default()
};
let read_set_valid = versioned_state_proxy.validate_read_set(&read_set);
let aborted = !read_set_valid && scheduler.try_validation_abort(tx_index);
if aborted { versioned_state_proxy.delete_writes(tx_index) }
task = match scheduler.finish_validation(tx_index, aborted) {
Some(execution_task) => execution_task,
None => Task::NoTask,
}
}
Task::NoTask => {
task = scheduler.next_task();
}
Task::Done => (),
}
}
});
handles.push(handle);
}
for handle in handles {
handle.join().unwrap();
}
}

#[rstest]
fn test_mt_decrease_validation_index() {
let target_index = 0;
let initial_validation_index = 1;
let decrease_counter_calls = 2;
let scheduler = Arc::new(
default_scheduler!(chunk_size: DEFAULT_CHUNK_SIZE, validation_index: initial_validation_index),
);
let mut handles = vec![];
for _ in 0..decrease_counter_calls {
let cloned_scheduler = Arc::clone(&scheduler);
let cloned2_scheduler = Arc::clone(&scheduler);
// Simulating part of `finish_execution` that decreases the validation index.
let handle_decrease_validation_index = std::thread::spawn(move || {
std::thread::sleep(Duration::from_secs(1));
cloned_scheduler.decrease_validation_index(target_index);
});
// Simulating another thread that calls the `next_task()` method and bypassing it by artificially advancing the validation index.
let handle_advance_validation_index = std::thread::spawn(move || {

cloned2_scheduler.next_version_to_validate();
});
handles.push(handle_decrease_validation_index);
handles.push(handle_advance_validation_index);
}
for handle in handles {
handle.join().unwrap();
}
let final_validation_index = scheduler.validation_index.load(Ordering::Acquire);
let final_decrease_counter = scheduler.decrease_counter.load(Ordering::Acquire);
assert_eq!(final_decrease_counter + final_validation_index, initial_validation_index + decrease_counter_calls - target_index );
}
4 changes: 4 additions & 0 deletions crates/blockifier/src/concurrency/scheduler.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,10 @@ use crate::concurrency::TxIndex;
#[path = "scheduler_test.rs"]
pub mod test;

#[cfg(test)]
#[path = "flow_test.rs"]
pub mod flow_test;

// TODO(Avi, 01/04/2024): Remove dead_code attribute.
#[allow(dead_code)]
#[derive(Debug, Default)]
Expand Down
30 changes: 27 additions & 3 deletions crates/blockifier/src/concurrency/versioned_state_proxy.rs
Original file line number Diff line number Diff line change
Expand Up @@ -139,6 +139,14 @@ impl<S: StateReader> VersionedState<S> {
self.compiled_contract_classes.write(tx_index, key, value.clone());
}
}

fn delete_writes(&mut self, tx_index: TxIndex) {
self.storage.delete_writes(tx_index);
self.nonces.delete_writes(tx_index);
self.class_hashes.delete_writes(tx_index);
self.compiled_class_hashes.delete_writes(tx_index);
self.compiled_contract_classes.delete_writes(tx_index);
}
}

pub struct ThreadSafeVersionedState<S: StateReader>(Arc<Mutex<VersionedState<S>>>);
Expand All @@ -150,18 +158,30 @@ impl<S: StateReader> ThreadSafeVersionedState<S> {
}

pub fn pin_version(&self, tx_index: TxIndex) -> VersionedStateProxy<S> {
VersionedStateProxy { tx_index, state: self.0.clone() }
VersionedStateProxy { tx_index, state: self.clone() }
}
}

impl<S: StateReader> Clone for ThreadSafeVersionedState<S> {
fn clone(&self) -> Self {
ThreadSafeVersionedState(Arc::clone(&self.0))
}
}

pub struct VersionedStateProxy<S: StateReader> {
pub tx_index: TxIndex,
pub state: Arc<Mutex<VersionedState<S>>>,
pub state: ThreadSafeVersionedState<S>,
}

impl<S: StateReader> Clone for VersionedStateProxy<S> {
fn clone(&self) -> Self {
VersionedStateProxy {tx_index: self.tx_index, state: self.state.clone()}
}
}

impl<S: StateReader> VersionedStateProxy<S> {
fn state(&self) -> LockedVersionedState<'_, S> {
self.state.lock().expect("Failed to acquire state lock.")
self.state.0.lock().expect("Failed to acquire state lock.")
}

pub fn validate_read_set(&self, reads: &StateMaps) -> bool {
Expand All @@ -171,6 +191,10 @@ impl<S: StateReader> VersionedStateProxy<S> {
pub fn apply_writes(&self, writes: &StateMaps, class_hash_to_class: &ContractClassMapping) {
self.state().apply_writes(self.tx_index, writes, class_hash_to_class)
}

pub fn delete_writes(&self, tx_index: TxIndex) {
self.state().delete_writes(tx_index);
}
}

impl<S: StateReader> StateReader for VersionedStateProxy<S> {
Expand Down
6 changes: 6 additions & 0 deletions crates/blockifier/src/concurrency/versioned_storage.rs
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,12 @@ where
cell.insert(tx_index, value);
}

pub fn delete_writes(&mut self, tx_index: TxIndex) {
for (_, inner_map) in self.writes.iter_mut() {
inner_map.retain(|&index, _| index != tx_index);
}
}

/// This method inserts the provided key-value pair into the cached initial values map.
/// It is typically used when reading a value that is not found in the versioned storage. In
/// such a scenario, the value is retrieved from the initial storage and written to the
Expand Down

0 comments on commit 03d3956

Please sign in to comment.