Skip to content
This repository has been archived by the owner on Aug 21, 2024. It is now read-only.

Commit

Permalink
fix(concurrency): change version to tx_index (#1739)
Browse files Browse the repository at this point in the history
  • Loading branch information
barak-b-starkware authored Apr 3, 2024
1 parent ebcbf44 commit 0899890
Show file tree
Hide file tree
Showing 5 changed files with 45 additions and 45 deletions.
2 changes: 1 addition & 1 deletion crates/blockifier/src/concurrency.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,4 +2,4 @@ pub mod scheduler;
pub mod versioned_state_proxy;
pub mod versioned_storage;

type Version = u64;
type TxIndex = usize;
16 changes: 8 additions & 8 deletions crates/blockifier/src/concurrency/scheduler.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
use std::sync::atomic::{AtomicUsize, Ordering};

use crate::concurrency::Version;
use crate::concurrency::TxIndex;

#[cfg(test)]
#[path = "scheduler_test.rs"]
Expand Down Expand Up @@ -57,18 +57,18 @@ impl Scheduler {
todo!()
}

fn decrease_validation_index(&self, target_index: usize) {
fn decrease_validation_index(&self, target_index: TxIndex) {
self.validation_index.fetch_min(target_index, Ordering::SeqCst);
self.decrease_counter.fetch_add(1, Ordering::SeqCst);
}

fn decrease_execution_index(&self, target_index: usize) {
fn decrease_execution_index(&self, target_index: TxIndex) {
self.execution_index.fetch_min(target_index, Ordering::SeqCst);
self.decrease_counter.fetch_add(1, Ordering::SeqCst);
}

/// Updates a transaction's status to `Executing` if it is ready to execute.
fn try_incarnate(&self, tx_index: usize) -> Option<usize> {
fn try_incarnate(&self, tx_index: TxIndex) -> Option<TxIndex> {
if tx_index < self.chunk_size {
// TODO(barak, 01/04/2024): complete try_incarnate logic.
return Some(tx_index);
Expand All @@ -77,7 +77,7 @@ impl Scheduler {
None
}

fn next_version_to_validate(&self) -> Option<usize> {
fn next_version_to_validate(&self) -> Option<TxIndex> {
let index_to_validate = self.validation_index.load(Ordering::Acquire);
if index_to_validate >= self.chunk_size {
self.check_done();
Expand All @@ -93,7 +93,7 @@ impl Scheduler {
None
}

fn next_version_to_execute(&self) -> Option<usize> {
fn next_version_to_execute(&self) -> Option<TxIndex> {
let index_to_execute = self.execution_index.load(Ordering::Acquire);
if index_to_execute >= self.chunk_size {
self.check_done();
Expand All @@ -106,8 +106,8 @@ impl Scheduler {
}

pub enum Task {
ExecutionTask(Version),
ValidationTask(Version),
ExecutionTask(TxIndex),
ValidationTask(TxIndex),
NoTask,
Done,
}
58 changes: 29 additions & 29 deletions crates/blockifier/src/concurrency/versioned_state_proxy.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ use starknet_api::hash::StarkFelt;
use starknet_api::state::StorageKey;

use crate::concurrency::versioned_storage::VersionedStorage;
use crate::concurrency::Version;
use crate::concurrency::TxIndex;
use crate::execution::contract_class::ContractClass;
use crate::state::cached_state::{ContractClassMapping, StateCache};
use crate::state::state_api::{State, StateReader, StateResult};
Expand Down Expand Up @@ -48,33 +48,33 @@ impl<S: StateReader> VersionedState<S> {
// TODO(Mohammad, 01/04/2024): Store the read set (and write set) within a shared
// object (probabily `VersionedState`). As RefCell operations are not thread-safe. Therefore,
// accessing this function should be protected by a mutex to ensure thread safety.
pub fn validate_read_set(&mut self, version: Version, state_cache: &mut StateCache) -> bool {
// If the version is 0, then the read set is valid. Since version 0 has no predecessors,
// there's nothing to compare it to.
if version == 0 {
pub fn validate_read_set(&mut self, tx_index: TxIndex, state_cache: &mut StateCache) -> bool {
// If is the first transaction in the chunk, then the read set is valid. Since it has no
// predecessors, there's nothing to compare it to.
if tx_index == 0 {
return true;
}
for (&(contract_address, storage_key), expected_value) in
&state_cache.storage_initial_values
{
let value =
self.storage.read(version, (contract_address, storage_key)).expect(READ_ERR);
self.storage.read(tx_index, (contract_address, storage_key)).expect(READ_ERR);

if &value != expected_value {
return false;
}
}

for (&contract_address, expected_value) in &state_cache.nonce_initial_values {
let value = self.nonces.read(version, contract_address).expect(READ_ERR);
let value = self.nonces.read(tx_index, contract_address).expect(READ_ERR);

if &value != expected_value {
return false;
}
}

for (&contract_address, expected_value) in &state_cache.class_hash_initial_values {
let value = self.class_hashes.read(version, contract_address).expect(READ_ERR);
let value = self.class_hashes.read(tx_index, contract_address).expect(READ_ERR);

if &value != expected_value {
return false;
Expand All @@ -83,7 +83,7 @@ impl<S: StateReader> VersionedState<S> {

// Added for symmetry. We currently do not update this initial mapping.
for (&class_hash, expected_value) in &state_cache.compiled_class_hash_initial_values {
let value = self.compiled_class_hashes.read(version, class_hash).expect(READ_ERR);
let value = self.compiled_class_hashes.read(tx_index, class_hash).expect(READ_ERR);

if &value != expected_value {
return false;
Expand All @@ -99,38 +99,38 @@ impl<S: StateReader> VersionedState<S> {

pub fn apply_writes(
&mut self,
version: Version,
tx_index: TxIndex,
state_cache: &mut StateCache,
class_hash_to_class: ContractClassMapping,
) {
for (&key, &value) in &state_cache.storage_writes {
self.storage.write(version, key, value);
self.storage.write(tx_index, key, value);
}
for (&key, &value) in &state_cache.nonce_writes {
self.nonces.write(version, key, value);
self.nonces.write(tx_index, key, value);
}
for (&key, &value) in &state_cache.class_hash_writes {
self.class_hashes.write(version, key, value);
self.class_hashes.write(tx_index, key, value);
}
for (&key, &value) in &state_cache.compiled_class_hash_writes {
self.compiled_class_hashes.write(version, key, value);
self.compiled_class_hashes.write(tx_index, key, value);
}
for (key, value) in class_hash_to_class {
self.compiled_contract_classes.write(version, key, value.clone());
self.compiled_contract_classes.write(tx_index, key, value.clone());
}
}
}

pub struct ThreadSafeVersionedState<S: StateReader>(Arc<Mutex<VersionedState<S>>>);

impl<S: StateReader> ThreadSafeVersionedState<S> {
pub fn pin_version(&self, version: Version) -> VersionedStateProxy<S> {
VersionedStateProxy { version, state: self.0.clone() }
pub fn pin_version(&self, tx_index: TxIndex) -> VersionedStateProxy<S> {
VersionedStateProxy { tx_index, state: self.0.clone() }
}
}

pub struct VersionedStateProxy<S: StateReader> {
pub version: Version,
pub tx_index: TxIndex,
pub state: Arc<Mutex<VersionedState<S>>>,
}

Expand All @@ -149,7 +149,7 @@ impl<S: StateReader> State for VersionedStateProxy<S> {
value: StarkFelt,
) -> StateResult<()> {
let mut state = self.state();
state.storage.write(self.version, (contract_address, key), value);
state.storage.write(self.tx_index, (contract_address, key), value);

Ok(())
}
Expand All @@ -160,20 +160,20 @@ impl<S: StateReader> State for VersionedStateProxy<S> {
class_hash: ClassHash,
) -> StateResult<()> {
let mut state = self.state();
state.class_hashes.write(self.version, contract_address, class_hash);
state.class_hashes.write(self.tx_index, contract_address, class_hash);

Ok(())
}

fn increment_nonce(&mut self, contract_address: ContractAddress) -> StateResult<()> {
let mut state = self.state();
let current_nonce = state.nonces.read(self.version, contract_address).unwrap();
let current_nonce = state.nonces.read(self.tx_index, contract_address).unwrap();

let current_nonce_as_u64: u64 =
usize::try_from(current_nonce.0)?.try_into().expect("Failed to convert usize to u64.");
let next_nonce_val = 1_u64 + current_nonce_as_u64;
let next_nonce = Nonce(StarkFelt::from(next_nonce_val));
state.nonces.write(self.version, contract_address, next_nonce);
state.nonces.write(self.tx_index, contract_address, next_nonce);

Ok(())
}
Expand All @@ -184,7 +184,7 @@ impl<S: StateReader> State for VersionedStateProxy<S> {
compiled_class_hash: CompiledClassHash,
) -> StateResult<()> {
let mut state = self.state();
state.compiled_class_hashes.write(self.version, class_hash, compiled_class_hash);
state.compiled_class_hashes.write(self.tx_index, class_hash, compiled_class_hash);

Ok(())
}
Expand All @@ -195,7 +195,7 @@ impl<S: StateReader> State for VersionedStateProxy<S> {
contract_class: ContractClass,
) -> StateResult<()> {
let mut state = self.state();
state.compiled_contract_classes.write(self.version, class_hash, contract_class);
state.compiled_contract_classes.write(self.tx_index, class_hash, contract_class);

Ok(())
}
Expand All @@ -212,7 +212,7 @@ impl<S: StateReader> StateReader for VersionedStateProxy<S> {
key: StorageKey,
) -> StateResult<StarkFelt> {
let mut state = self.state();
match state.storage.read(self.version, (contract_address, key)) {
match state.storage.read(self.tx_index, (contract_address, key)) {
Some(value) => Ok(value),
None => {
let initial_value = state.initial_state.get_storage_at(contract_address, key)?;
Expand All @@ -224,7 +224,7 @@ impl<S: StateReader> StateReader for VersionedStateProxy<S> {

fn get_nonce_at(&self, contract_address: ContractAddress) -> StateResult<Nonce> {
let mut state = self.state();
match state.nonces.read(self.version, contract_address) {
match state.nonces.read(self.tx_index, contract_address) {
Some(value) => Ok(value),
None => {
let initial_value = state.initial_state.get_nonce_at(contract_address)?;
Expand All @@ -236,7 +236,7 @@ impl<S: StateReader> StateReader for VersionedStateProxy<S> {

fn get_class_hash_at(&self, contract_address: ContractAddress) -> StateResult<ClassHash> {
let mut state = self.state();
match state.class_hashes.read(self.version, contract_address) {
match state.class_hashes.read(self.tx_index, contract_address) {
Some(value) => Ok(value),
None => {
let initial_value = state.initial_state.get_class_hash_at(contract_address)?;
Expand All @@ -248,7 +248,7 @@ impl<S: StateReader> StateReader for VersionedStateProxy<S> {

fn get_compiled_class_hash(&self, class_hash: ClassHash) -> StateResult<CompiledClassHash> {
let mut state = self.state();
match state.compiled_class_hashes.read(self.version, class_hash) {
match state.compiled_class_hashes.read(self.tx_index, class_hash) {
Some(value) => Ok(value),
None => {
let initial_value = state.initial_state.get_compiled_class_hash(class_hash)?;
Expand All @@ -260,7 +260,7 @@ impl<S: StateReader> StateReader for VersionedStateProxy<S> {

fn get_compiled_contract_class(&self, class_hash: ClassHash) -> StateResult<ContractClass> {
let mut state = self.state();
match state.compiled_contract_classes.read(self.version, class_hash) {
match state.compiled_contract_classes.read(self.tx_index, class_hash) {
Some(value) => Ok(value),
None => {
let initial_value = state.initial_state.get_compiled_contract_class(class_hash)?;
Expand Down
12 changes: 6 additions & 6 deletions crates/blockifier/src/concurrency/versioned_storage.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ use std::collections::{BTreeMap, HashMap};
use std::fmt::Debug;
use std::hash::Hash;

use crate::concurrency::Version;
use crate::concurrency::TxIndex;

#[cfg(test)]
#[path = "versioned_storage_test.rs"]
Expand All @@ -18,7 +18,7 @@ where
V: Clone + Debug,
{
cached_initial_values: HashMap<K, V>,
writes: HashMap<K, BTreeMap<Version, V>>,
writes: HashMap<K, BTreeMap<TxIndex, V>>,
}

impl<K, V> Default for VersionedStorage<K, V>
Expand All @@ -38,16 +38,16 @@ where
K: Clone + Copy + Eq + Hash + Debug,
V: Clone + Debug,
{
pub fn read(&self, version: Version, key: K) -> Option<V> {
pub fn read(&self, tx_index: TxIndex, key: K) -> Option<V> {
// Ignore the writes in the current transaction (may contain an `ESTIMATE` value). Reading
// the value written in this transaction should be handled by the state.
let value = self.writes.get(&key).and_then(|cell| cell.range(..version).next_back());
let value = self.writes.get(&key).and_then(|cell| cell.range(..tx_index).next_back());
value.map(|(_, value)| value).or_else(|| self.cached_initial_values.get(&key)).cloned()
}

pub fn write(&mut self, version: Version, key: K, value: V) {
pub fn write(&mut self, tx_index: TxIndex, key: K, value: V) {
let cell = self.writes.entry(key).or_default();
cell.insert(version, value);
cell.insert(tx_index, value);
}

/// This method inserts the provided key-value pair into the cached initial values map.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ fn test_versioned_storage() {
// Read from the past.
storage.write(2, 10, 78);
assert_eq!(storage.read(1, 10).unwrap(), 31);
// Ignore the value written by the current version.
// Ignore the value written by the current transaction.
assert_eq!(storage.read(2, 10).unwrap(), 31);

// Read uninitialized cell.
Expand Down

0 comments on commit 0899890

Please sign in to comment.