Skip to content
This repository has been archived by the owner on Aug 21, 2024. It is now read-only.

Commit

Permalink
feat(concurrency): delete writes (#1892)
Browse files Browse the repository at this point in the history
  • Loading branch information
barak-b-starkware authored May 20, 2024
1 parent 4c760c7 commit 63c8b7e
Show file tree
Hide file tree
Showing 5 changed files with 245 additions and 24 deletions.
19 changes: 19 additions & 0 deletions crates/blockifier/src/concurrency/test_utils.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,8 @@
use rstest::fixture;
use starknet_api::core::{ClassHash, ContractAddress, PatriciaKey};
use starknet_api::hash::StarkHash;
use starknet_api::{class_hash, contract_address, patricia_key};

use crate::concurrency::versioned_state::{ThreadSafeVersionedState, VersionedState};
use crate::context::BlockContext;
use crate::execution::call_info::CallInfo;
Expand All @@ -7,6 +12,20 @@ use crate::test_utils::dict_state_reader::DictStateReader;
use crate::transaction::account_transaction::AccountTransaction;
use crate::transaction::transactions::ExecutableTransaction;

// Fixtures.

#[fixture]
pub fn contract_address() -> ContractAddress {
contract_address!("0x18031991")
}

#[fixture]
pub fn class_hash() -> ClassHash {
class_hash!(27_u8)
}

// Macros.

#[macro_export]
macro_rules! default_scheduler {
($chunk_size:ident : $chunk:expr , $($field:ident $(: $value:expr)?),+ $(,)?) => {
Expand Down
56 changes: 49 additions & 7 deletions crates/blockifier/src/concurrency/versioned_state.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ const READ_ERR: &str = "Error: read value missing in the versioned storage";
/// Represents a versioned state used as shared state between a chunk of workers.
/// This state facilitates concurrent operations.
/// Reader functionality is injected through initial state.
#[derive(Debug)]
pub struct VersionedState<S: StateReader> {
initial_state: S,
storage: VersionedStorage<(ContractAddress, StorageKey), StarkFelt>,
Expand All @@ -42,12 +43,25 @@ impl<S: StateReader> VersionedState<S> {
}
}

fn get_writes(&mut self, from_index: TxIndex) -> StateMaps {
fn get_writes_up_to_index(&mut self, tx_index: TxIndex) -> StateMaps {
StateMaps {
storage: self.storage.get_writes_from_index(from_index),
nonces: self.nonces.get_writes_from_index(from_index),
class_hashes: self.class_hashes.get_writes_from_index(from_index),
compiled_class_hashes: self.compiled_class_hashes.get_writes_from_index(from_index),
storage: self.storage.get_writes_up_to_index(tx_index),
nonces: self.nonces.get_writes_up_to_index(tx_index),
class_hashes: self.class_hashes.get_writes_up_to_index(tx_index),
compiled_class_hashes: self.compiled_class_hashes.get_writes_up_to_index(tx_index),
// TODO(OriF, 01/07/2024): Update declared_contracts initial value.
declared_contracts: HashMap::new(),
}
}

#[cfg(any(feature = "testing", test))]
pub fn get_writes_of_index(&self, tx_index: TxIndex) -> StateMaps {
StateMaps {
storage: self.storage.get_writes_of_index(tx_index),
nonces: self.nonces.get_writes_of_index(tx_index),
class_hashes: self.class_hashes.get_writes_of_index(tx_index),
compiled_class_hashes: self.compiled_class_hashes.get_writes_of_index(tx_index),
// TODO(OriF, 01/07/2024): Update declared_contracts initial value.
declared_contracts: HashMap::new(),
}
}
Expand All @@ -56,11 +70,11 @@ impl<S: StateReader> VersionedState<S> {
where
T: StateReader,
{
let writes = self.get_writes(from_index);
let writes = self.get_writes_up_to_index(from_index);
parent_state.update_cache(writes);

parent_state.update_contract_class_cache(
self.compiled_contract_classes.get_writes_from_index(from_index),
self.compiled_contract_classes.get_writes_up_to_index(from_index),
);
}

Expand Down Expand Up @@ -139,6 +153,30 @@ impl<S: StateReader> VersionedState<S> {
self.compiled_contract_classes.write(tx_index, key, value.clone());
}
}

fn delete_writes(
&mut self,
tx_index: TxIndex,
writes: &StateMaps,
class_hash_to_class: &ContractClassMapping,
) {
for &key in writes.storage.keys() {
self.storage.delete_write(key, tx_index);
}
for &key in writes.nonces.keys() {
self.nonces.delete_write(key, tx_index);
}
for &key in writes.class_hashes.keys() {
self.class_hashes.delete_write(key, tx_index);
}
for &key in writes.compiled_class_hashes.keys() {
self.compiled_class_hashes.delete_write(key, tx_index);
}
// TODO(OriF, 01/07/2024): Add a for loop for `declared_contracts`.
for &key in class_hash_to_class.keys() {
self.compiled_contract_classes.delete_write(key, tx_index);
}
}
}

pub struct ThreadSafeVersionedState<S: StateReader>(Arc<Mutex<VersionedState<S>>>);
Expand Down Expand Up @@ -177,6 +215,10 @@ impl<S: StateReader> VersionedStateProxy<S> {
pub fn apply_writes(&self, writes: &StateMaps, class_hash_to_class: &ContractClassMapping) {
self.state().apply_writes(self.tx_index, writes, class_hash_to_class)
}

pub fn delete_writes(&self, writes: &StateMaps, class_hash_to_class: &ContractClassMapping) {
self.state().delete_writes(self.tx_index, writes, class_hash_to_class);
}
}

impl<S: StateReader> StateReader for VersionedStateProxy<S> {
Expand Down
135 changes: 120 additions & 15 deletions crates/blockifier/src/concurrency/versioned_state_test.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,12 +9,15 @@ use starknet_api::transaction::{Calldata, ContractAddressSalt, Fee, TransactionV
use starknet_api::{calldata, class_hash, contract_address, patricia_key, stark_felt};

use crate::abi::abi_utils::{get_fee_token_var_address, get_storage_var_address};
use crate::concurrency::test_utils::safe_versioned_state_for_testing;
use crate::concurrency::test_utils::{
class_hash, contract_address, safe_versioned_state_for_testing,
};
use crate::concurrency::versioned_state::{
ThreadSafeVersionedState, VersionedState, VersionedStateProxy,
};
use crate::concurrency::TxIndex;
use crate::context::BlockContext;
use crate::state::cached_state::{CachedState, StateMaps};
use crate::state::cached_state::{CachedState, ContractClassMapping, StateMaps};
use crate::state::state_api::{State, StateReader};
use crate::test_utils::contracts::FeatureContract;
use crate::test_utils::deploy_account::deploy_account_tx;
Expand All @@ -27,19 +30,6 @@ use crate::transaction::test_utils::l1_resource_bounds;
use crate::transaction::transactions::ExecutableTransaction;
use crate::{compiled_class_hash, deploy_account_tx_args, nonce, storage_key};

const TEST_CONTRACT_ADDRESS: &str = "0x1";
const TEST_CLASS_HASH: u8 = 27_u8;

#[fixture]
pub fn contract_address() -> ContractAddress {
contract_address!(TEST_CONTRACT_ADDRESS)
}

#[fixture]
pub fn class_hash() -> ClassHash {
class_hash!(TEST_CLASS_HASH)
}

#[fixture]
pub fn safe_versioned_state(
contract_address: ContractAddress,
Expand Down Expand Up @@ -371,3 +361,118 @@ fn test_apply_writes_reexecute_scenario(
// The class hash should be updated.
assert!(transactional_states[1].get_class_hash_at(contract_address).unwrap() == class_hash_0);
}

#[rstest]
fn test_delete_writes(
#[values(0, 1, 2)] tx_index_to_delete_writes: TxIndex,
safe_versioned_state: ThreadSafeVersionedState<DictStateReader>,
) {
let num_of_txs = 3;
let mut transactional_states: Vec<CachedState<VersionedStateProxy<DictStateReader>>> =
(0..num_of_txs).map(|i| CachedState::from(safe_versioned_state.pin_version(i))).collect();
// Setting 2 instances of the contract to ensure `delete_writes` removes information from
// multiple keys. Class hash values are not checked in this test.
let contract_addresses = [
(contract_address!("0x100"), class_hash!(20_u8)),
(contract_address!("0x200"), class_hash!(21_u8)),
];
let feature_contract = FeatureContract::TestContract(CairoVersion::Cairo1);
for tx_state in transactional_states.iter_mut() {
// Modify the `cache` member of the CachedState.
for (contract_address, class_hash) in contract_addresses.iter() {
tx_state.set_class_hash_at(*contract_address, *class_hash).unwrap();
}
// Modify the `class_hash_to_class` member of the CachedState.
tx_state
.set_contract_class(feature_contract.get_class_hash(), feature_contract.get_class())
.unwrap();
tx_state
.state
.apply_writes(&tx_state.cache.borrow().writes, &tx_state.class_hash_to_class.borrow());
}

transactional_states[tx_index_to_delete_writes].state.delete_writes(
&transactional_states[tx_index_to_delete_writes].cache.borrow().writes,
&transactional_states[tx_index_to_delete_writes].class_hash_to_class.borrow(),
);

for tx_index in 0..num_of_txs {
let should_be_empty = tx_index == tx_index_to_delete_writes;
assert_eq!(
safe_versioned_state
.0
.lock()
.unwrap()
.get_writes_of_index(tx_index)
.class_hashes
.is_empty(),
should_be_empty
);

assert_eq!(
safe_versioned_state
.0
.lock()
.unwrap()
.compiled_contract_classes
.get_writes_of_index(tx_index)
.is_empty(),
should_be_empty
);
}
}

#[rstest]
fn test_delete_writes_completeness(
safe_versioned_state: ThreadSafeVersionedState<DictStateReader>,
) {
let state_maps_writes = StateMaps {
nonces: HashMap::from([(contract_address!("0x1"), nonce!("0x1"))]),
class_hashes: HashMap::from([(contract_address!("0x1"), class_hash!("0x1"))]),
storage: HashMap::from([(
(contract_address!("0x1"), storage_key!("0x1")),
stark_felt!("0x1"),
)]),
compiled_class_hashes: HashMap::from([(class_hash!("0x1"), compiled_class_hash!("0x1"))]),
// TODO (OriF, 01/07/2024): Uncomment the following line and remove the line below it once
// `declared_contracts` mapping logic in StateMaps is complete.
// declared_contracts: HashMap::from([(class_hash!("0x1"), true)]),
declared_contracts: HashMap::default(),
};
let feature_contract = FeatureContract::TestContract(CairoVersion::Cairo1);
let class_hash_to_class_writes =
HashMap::from([(feature_contract.get_class_hash(), feature_contract.get_class())]);

let tx_index = 0;
let versioned_state_proxy = safe_versioned_state.pin_version(tx_index);

versioned_state_proxy.apply_writes(&state_maps_writes, &class_hash_to_class_writes);
assert_eq!(
safe_versioned_state.0.lock().unwrap().get_writes_of_index(tx_index),
state_maps_writes
);
assert_eq!(
safe_versioned_state
.0
.lock()
.unwrap()
.compiled_contract_classes
.get_writes_of_index(tx_index),
class_hash_to_class_writes
);

versioned_state_proxy.delete_writes(&state_maps_writes, &class_hash_to_class_writes);
assert_eq!(
safe_versioned_state.0.lock().unwrap().get_writes_of_index(tx_index),
StateMaps::default()
);
assert_eq!(
safe_versioned_state
.0
.lock()
.unwrap()
.compiled_contract_classes
.get_writes_of_index(tx_index),
ContractClassMapping::default()
);
}
26 changes: 24 additions & 2 deletions crates/blockifier/src/concurrency/versioned_storage.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ pub mod test;
/// It is versioned in the sense that it holds a state of write operations done on it by
/// different versions of executions.
/// This allows maintaining the cells with the correct values in the context of each execution.
#[derive(Debug)]
pub struct VersionedStorage<K, V>
where
K: Clone + Copy + Eq + Hash + Debug,
Expand Down Expand Up @@ -50,6 +51,16 @@ where
cell.insert(tx_index, value);
}

pub fn delete_write(&mut self, key: K, tx_index: TxIndex) {
self.writes
.get_mut(&key)
.expect(
"A 'delete_write' call must be preceded by a 'write' call with the corresponding \
key",
)
.remove(&tx_index);
}

/// This method inserts the provided key-value pair into the cached initial values map.
/// It is typically used when reading a value that is not found in the versioned storage. In
/// such a scenario, the value is retrieved from the initial storage and written to the
Expand All @@ -58,13 +69,24 @@ where
self.cached_initial_values.insert(key, value);
}

pub(crate) fn get_writes_from_index(&self, from_index: TxIndex) -> HashMap<K, V> {
pub(crate) fn get_writes_up_to_index(&self, index: TxIndex) -> HashMap<K, V> {
let mut writes = HashMap::default();
for (&key, cell) in self.writes.iter() {
if let Some(value) = cell.range(..=from_index).next_back() {
if let Some(value) = cell.range(..=index).next_back() {
writes.insert(key, value.1.clone());
}
}
writes
}

#[cfg(any(feature = "testing", test))]
pub fn get_writes_of_index(&self, tx_index: TxIndex) -> HashMap<K, V> {
let mut writes = HashMap::default();
for (&key, cell) in self.writes.iter() {
if let Some(value) = cell.get(&tx_index) {
writes.insert(key, value.clone());
}
}
writes
}
}
33 changes: 33 additions & 0 deletions crates/blockifier/src/concurrency/versioned_storage_test.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,14 @@
use std::collections::{BTreeMap, HashMap};

use pretty_assertions::assert_eq;
use rstest::rstest;
use starknet_api::core::{ClassHash, ContractAddress};

use crate::concurrency::test_utils::{class_hash, contract_address};
use crate::concurrency::versioned_storage::VersionedStorage;
use crate::concurrency::TxIndex;

// TODO(barak, 01/07/2024): Split into test_read() and test_write().
#[test]
fn test_versioned_storage() {
let mut storage = VersionedStorage::default();
Expand Down Expand Up @@ -37,3 +44,29 @@ fn test_versioned_storage() {
// Test the write.
assert_eq!(storage.read(50, 100).unwrap(), 194);
}

#[rstest]
fn test_delete_write(
contract_address: ContractAddress,
class_hash: ClassHash,
#[values(0, 1, 2)] tx_index_to_delete_writes: TxIndex,
) {
// TODO(barak, 01/07/2025): Create a macro versioned_storage!.
let num_of_txs = 3;
let mut versioned_storage = VersionedStorage {
cached_initial_values: HashMap::default(),
writes: HashMap::from([(
contract_address,
// Class hash values are not checked in this test.
BTreeMap::from_iter((0..num_of_txs).map(|i| (i, class_hash))),
)]),
};
for tx_index in 0..num_of_txs {
let should_contain_tx_index_writes = tx_index != tx_index_to_delete_writes;
versioned_storage.delete_write(contract_address, tx_index_to_delete_writes);
assert_eq!(
versioned_storage.writes.get(&contract_address).unwrap().contains_key(&tx_index),
should_contain_tx_index_writes
)
}
}

0 comments on commit 63c8b7e

Please sign in to comment.