Skip to content
This repository has been archived by the owner on Aug 21, 2024. It is now read-only.

Commit

Permalink
refactor(concurrency): use apply_writes in the commit of VersionedSta…
Browse files Browse the repository at this point in the history
…te (#1948)
  • Loading branch information
barak-b-starkware authored Jun 6, 2024
1 parent 2327bdb commit 702f401
Show file tree
Hide file tree
Showing 2 changed files with 31 additions and 20 deletions.
38 changes: 22 additions & 16 deletions crates/blockifier/src/concurrency/versioned_state.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ use starknet_api::state::StorageKey;
use crate::concurrency::versioned_storage::VersionedStorage;
use crate::concurrency::TxIndex;
use crate::execution::contract_class::ContractClass;
use crate::state::cached_state::{CachedState, ContractClassMapping, StateMaps};
use crate::state::cached_state::{ContractClassMapping, StateMaps};
use crate::state::errors::StateError;
use crate::state::state_api::{StateReader, StateResult, UpdatableState};

Expand Down Expand Up @@ -71,18 +71,6 @@ impl<S: StateReader> VersionedState<S> {
}
}

pub fn commit<T>(&mut self, from_index: TxIndex, parent_state: &mut CachedState<T>)
where
T: StateReader,
{
let writes = self.get_writes_up_to_index(from_index);

parent_state.update_cache(
&writes,
self.compiled_contract_classes.get_writes_up_to_index(from_index),
);
}

// TODO(Mohammad, 01/04/2024): Store the read set (and write set) within a shared
// object (probabily `VersionedState`). As RefCell operations are not thread-safe. Therefore,
// accessing this function should be protected by a mutex to ensure thread safety.
Expand Down Expand Up @@ -199,12 +187,31 @@ impl<S: StateReader> VersionedState<S> {
}
}

#[allow(dead_code)]
fn into_initial_state(self) -> S {
self.initial_state
}
}

impl<U: UpdatableState> VersionedState<U> {
pub fn commit_chunk_and_recover_block_state(mut self, n_committed_txs: usize) -> U {
if n_committed_txs == 0 {
return self.into_initial_state();
}
let commit_index = n_committed_txs - 1;
let writes = self.get_writes_up_to_index(commit_index);
let class_hash_to_class =
self.compiled_contract_classes.get_writes_up_to_index(commit_index);
let mut state = self.into_initial_state();
// TODO(barak, 01/08/2024): Add visited_pcs argument to `apply_writes`.
state.apply_writes(&writes, &class_hash_to_class, &HashMap::default());
state
}
}

// TODO(barak, 01/07/2024): Re-consider the API (pub functions) of VersionedState,
// ThreadSafeVersionedState and VersionedStateProxy.
// TODO(barak, 01/07/2024): Re-consider the necessity ot ThreadSafeVersionedState once the worker
// logic is completed.
pub struct ThreadSafeVersionedState<S: StateReader>(Arc<Mutex<VersionedState<S>>>);
pub type LockedVersionedState<'a, S> = MutexGuard<'a, VersionedState<S>>;

Expand All @@ -217,8 +224,7 @@ impl<S: StateReader> ThreadSafeVersionedState<S> {
VersionedStateProxy { tx_index, state: self.0.clone() }
}

#[allow(dead_code)]
fn into_inner_state(self) -> VersionedState<S> {
pub fn into_inner_state(self) -> VersionedState<S> {
Arc::try_unwrap(self.0)
.unwrap_or_else(|_| {
panic!(
Expand Down
13 changes: 9 additions & 4 deletions crates/blockifier/src/concurrency/versioned_state_test.rs
Original file line number Diff line number Diff line change
Expand Up @@ -532,7 +532,6 @@ fn test_versioned_proxy_state_flow(
let contract_address = contract_address!("0x1");
let class_hash = ClassHash(stark_felt!(27_u8));

let mut block_state = CachedState::from(DictStateReader::default());
let mut versioned_proxy_states: Vec<VersionedStateProxy<CachedState<DictStateReader>>> =
(0..4).map(|i| safe_versioned_state.pin_version(i)).collect();

Expand Down Expand Up @@ -566,8 +565,14 @@ fn test_versioned_proxy_state_flow(
}

// Check the final state.
safe_versioned_state.0.lock().unwrap().commit(4, &mut block_state);
for proxy in versioned_proxy_states {
drop(proxy);
}
let modified_block_state =
safe_versioned_state.into_inner_state().commit_chunk_and_recover_block_state(4);

assert!(block_state.get_class_hash_at(contract_address).unwrap() == class_hash_3);
assert!(block_state.get_compiled_contract_class(class_hash).unwrap() == contract_class_2);
assert!(modified_block_state.get_class_hash_at(contract_address).unwrap() == class_hash_3);
assert!(
modified_block_state.get_compiled_contract_class(class_hash).unwrap() == contract_class_2
);
}

0 comments on commit 702f401

Please sign in to comment.