Skip to content

Commit

Permalink
Add mock_call (#385)
Browse files Browse the repository at this point in the history
<!-- Reference any GitHub issues resolved by this PR -->

Closes #

## Introduced changes

<!-- A brief description of the changes -->

- Add `start_mock_call` cheatcode (with tests and docs)
- Add `stop_mock_call` cheatcode (with tests and docs)

## Breaking changes

<!-- List of all breaking changes, if applicable -->

## Checklist

<!-- Make sure all of these are complete -->

- [x] Linked relevant issue
- [x] Updated relevant documentation
- [x] Added relevant tests
- [x] Performed self-review of the code
- [x] Added changes to `CHANGELOG.md`

---------

Co-authored-by: Marcin Warchoł <[email protected]>
Co-authored-by: Maksymilian Demitraszek <[email protected]>
Co-authored-by: Piotr Magiera <[email protected]>
  • Loading branch information
4 people authored Aug 16, 2023
1 parent 1a6bddc commit 1981e86
Show file tree
Hide file tree
Showing 18 changed files with 1,172 additions and 3 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Print support for basic numeric data types
- Functions `parse_txt` and `TxtParser<T>::deserialize_txt` to load data from plain text files and serialize it
- `get_class_hash` cheatcode
- `mock_call` cheatcode

#### Changed

Expand Down
1 change: 1 addition & 0 deletions crates/cheatnet/src/cheatcodes.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ use thiserror::Error;
pub mod declare;
pub mod deploy;
pub mod get_class_hash;
pub mod mock_call;
pub mod prank;
pub mod roll;
pub mod warp;
Expand Down
36 changes: 36 additions & 0 deletions crates/cheatnet/src/cheatcodes/mock_call.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
use crate::CheatnetState;
use starknet_api::core::{ContractAddress, EntryPointSelector};
use starknet_api::hash::StarkFelt;
use std::collections::HashMap;

impl CheatnetState {
pub fn start_mock_call(
&mut self,
contract_address: ContractAddress,
function_name: EntryPointSelector,
ret_data: Vec<StarkFelt>,
) {
let contract_mocked_functions = self
.cheatcode_state
.mocked_functions
.entry(contract_address)
.or_insert_with(HashMap::new);

contract_mocked_functions.insert(function_name, ret_data);
}

pub fn stop_mock_call(
&mut self,
contract_address: ContractAddress,
function_name: EntryPointSelector,
) {
if let std::collections::hash_map::Entry::Occupied(mut e) = self
.cheatcode_state
.mocked_functions
.entry(contract_address)
{
let contract_mocked_functions = e.get_mut();
contract_mocked_functions.remove(&function_name);
}
}
}
36 changes: 36 additions & 0 deletions crates/cheatnet/src/rpc.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,12 +12,15 @@ use cairo_vm::{
vm_core::VirtualMachine,
},
};
use std::collections::HashSet;
use std::{any::Any, collections::HashMap, sync::Arc};

use crate::{
constants::{build_block_context, build_transaction_context, TEST_ACCOUNT_CONTRACT_ADDRESS},
CheatnetState,
};
use blockifier::execution::entry_point::CallExecution;
use blockifier::execution::entry_point::Retdata;
use blockifier::{
abi::constants,
execution::{
Expand Down Expand Up @@ -56,6 +59,7 @@ use cairo_lang_casm::{
hints::{Hint, StarknetHint},
operand::{BinOpOperand, DerefOrImmediate, Operation, Register, ResOperand},
};
use cairo_vm::vm::runners::cairo_runner::ExecutionResources as VmExecutionResources;
use starknet_api::{
core::{ClassHash, ContractAddress, EntryPointSelector, PatriciaKey},
deprecated_contract_class::EntryPointType,
Expand Down Expand Up @@ -693,6 +697,21 @@ pub fn execute_inner_call(
Ok(retdata_segment)
}

fn get_ret_data_by_call_entry_point<'a>(
call: &CallEntryPoint,
cheatcode_state: &'a CheatcodeState,
) -> Option<&'a Vec<StarkFelt>> {
if let Some(contract_address) = call.code_address {
if let Some(contract_functions) = cheatcode_state.mocked_functions.get(&contract_address) {
let entrypoint_selector = call.entry_point_selector;

let ret_data = contract_functions.get(&entrypoint_selector);
return ret_data;
}
}
None
}

/// Executes a specific call to a contract entry point and returns its output.
fn execute_entry_point_call_cairo1(
call: CallEntryPoint,
Expand All @@ -702,6 +721,23 @@ fn execute_entry_point_call_cairo1(
resources: &mut ExecutionResources,
context: &mut EntryPointExecutionContext,
) -> EntryPointExecutionResult<CallInfo> {
if let Some(ret_data) = get_ret_data_by_call_entry_point(&call, cheatcode_state) {
return Ok(CallInfo {
call,
execution: CallExecution {
retdata: Retdata(ret_data.clone()),
events: vec![],
l2_to_l1_messages: vec![],
failed: false,
gas_consumed: 0,
},
vm_resources: VmExecutionResources::default(),
inner_calls: vec![],
storage_read_values: vec![],
accessed_storage_keys: HashSet::new(),
});
}

let VmExecutionContext {
mut runner,
mut vm,
Expand Down
3 changes: 3 additions & 0 deletions crates/cheatnet/src/state.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ use blockifier::{
},
};
use cairo_felt::Felt252;
use starknet_api::core::EntryPointSelector;
use starknet_api::{
core::{ClassHash, CompiledClassHash, ContractAddress, Nonce},
hash::StarkFelt,
Expand Down Expand Up @@ -86,6 +87,7 @@ pub struct CheatcodeState {
pub rolled_contracts: HashMap<ContractAddress, Felt252>,
pub pranked_contracts: HashMap<ContractAddress, ContractAddress>,
pub warped_contracts: HashMap<ContractAddress, Felt252>,
pub mocked_functions: HashMap<ContractAddress, HashMap<EntryPointSelector, Vec<StarkFelt>>>,
}

impl CheatcodeState {
Expand All @@ -95,6 +97,7 @@ impl CheatcodeState {
rolled_contracts: HashMap::new(),
pranked_contracts: HashMap::new(),
warped_contracts: HashMap::new(),
mocked_functions: HashMap::new(),
}
}
}
Expand Down
39 changes: 37 additions & 2 deletions crates/forge/src/cheatcodes_hint_processor.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,8 @@ use std::path::PathBuf;

use crate::scarb::StarknetContractArtifacts;
use anyhow::{anyhow, Result};
use blockifier::abi::abi_utils::selector_from_name;
use blockifier::execution::execution_utils::{felt_to_stark_felt, stark_felt_to_felt};

use cairo_felt::Felt252;
use cairo_vm::hint_processor::hint_processor_definition::HintProcessorLogic;
use cairo_vm::hint_processor::hint_processor_definition::HintReference;
Expand Down Expand Up @@ -226,7 +226,42 @@ impl CairoHintProcessor<'_> {
self.cheatnet_state.stop_prank(contract_address);
Ok(())
}
"mock_call" => todo!(),
"start_mock_call" => {
let contract_address = ContractAddress(PatriciaKey::try_from(StarkFelt::new(
inputs[0].clone().to_be_bytes(),
)?)?);
let function_name = inputs[1].clone();
let function_name = as_cairo_short_string(&function_name).unwrap_or_else(|| {
panic!("Failed to convert {function_name:?} to Cairo short str")
});
let function_name = selector_from_name(function_name.as_str());

let ret_data_length = inputs[2]
.to_usize()
.expect("Missing ret_data len in inputs");
let mut ret_data = vec![];
for felt in inputs.iter().skip(3).take(ret_data_length) {
ret_data.push(felt_to_stark_felt(&felt.clone()));
}

self.cheatnet_state
.start_mock_call(contract_address, function_name, ret_data);
Ok(())
}
"stop_mock_call" => {
let contract_address = ContractAddress(PatriciaKey::try_from(StarkFelt::new(
inputs[0].clone().to_be_bytes(),
)?)?);
let function_name = inputs[1].clone();
let function_name = as_cairo_short_string(&function_name).unwrap_or_else(|| {
panic!("Failed to convert {function_name:?} to Cairo short str")
});
let function_name = selector_from_name(function_name.as_str());

self.cheatnet_state
.stop_mock_call(contract_address, function_name);
Ok(())
}
"declare" => {
let contract_name = inputs[0].clone();

Expand Down
34 changes: 34 additions & 0 deletions crates/forge/tests/data/contracts/constructor_mock_checker.cairo
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
use starknet::ContractAddress;

#[starknet::interface]
trait IConstructorMockChecker<TContractState> {
fn get_stored_thing(ref self: TContractState) -> felt252;
fn get_constant_thing(ref self: TContractState) -> felt252;
}

#[starknet::contract]
mod ConstructorMockChecker {
use super::IConstructorMockChecker;

#[storage]
struct Storage {
stored_thing: felt252,
}

#[constructor]
fn constructor(ref self: ContractState) {
let const_thing = self.get_constant_thing();
self.stored_thing.write(const_thing);
}

#[external(v0)]
impl IConstructorMockCheckerImpl of super::IConstructorMockChecker<ContractState> {
fn get_constant_thing(ref self: ContractState) -> felt252 {
13
}

fn get_stored_thing(ref self: ContractState) -> felt252 {
self.stored_thing.read()
}
}
}
54 changes: 54 additions & 0 deletions crates/forge/tests/data/contracts/mock_checker.cairo
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
#[derive(Serde, Drop)]
struct StructThing {
item_one: felt252,
item_two: felt252,
}

#[starknet::interface]
trait IMockChecker<TContractState> {
fn get_thing(ref self: TContractState) -> felt252;
fn get_thing_wrapper(ref self: TContractState) -> felt252;
fn get_constant_thing(ref self: TContractState) -> felt252;
fn get_struct_thing(ref self: TContractState) -> StructThing;
fn get_arr_thing(ref self: TContractState) -> Array<StructThing>;
}

#[starknet::contract]
mod MockChecker {
use super::IMockChecker;
use super::StructThing;
use array::ArrayTrait;

#[storage]
struct Storage {
stored_thing: felt252
}

#[constructor]
fn constructor(ref self: ContractState, arg1: felt252) {
self.stored_thing.write(arg1)
}

#[external(v0)]
impl IMockCheckerImpl of super::IMockChecker<ContractState> {
fn get_thing(ref self: ContractState) -> felt252 {
self.stored_thing.read()
}

fn get_thing_wrapper(ref self: ContractState) -> felt252 {
self.get_thing()
}

fn get_constant_thing(ref self: ContractState) -> felt252 {
13
}

fn get_struct_thing(ref self: ContractState) -> StructThing {
StructThing {item_one: 12, item_two: 21}
}

fn get_arr_thing(ref self: ContractState) -> Array<StructThing> {
array![StructThing {item_one: 12, item_two: 21}]
}
}
}
28 changes: 28 additions & 0 deletions crates/forge/tests/data/contracts/mock_checker_library_call.cairo
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
use starknet::ClassHash;

#[starknet::interface]
trait IMockChecker<TContractState> {
fn get_constant_thing(ref self: TContractState) -> felt252;
}

#[starknet::interface]
trait IMockCheckerLibCall<TContractState> {
fn get_constant_thing_with_lib_call(ref self: TContractState, class_hash: ClassHash) -> felt252;
}

#[starknet::contract]
mod MockCheckerLibCall {
use super::{IMockCheckerDispatcherTrait, IMockCheckerLibraryDispatcher};
use starknet::ClassHash;

#[storage]
struct Storage {}

#[external(v0)]
impl IMockCheckerLibCall of super::IMockCheckerLibCall<ContractState> {
fn get_constant_thing_with_lib_call(ref self: ContractState, class_hash: ClassHash) -> felt252 {
let mock_checker = IMockCheckerLibraryDispatcher { class_hash };
mock_checker.get_constant_thing()
}
}
}
49 changes: 49 additions & 0 deletions crates/forge/tests/data/contracts/mock_checker_proxy.cairo
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
use starknet::ContractAddress;

#[starknet::interface]
trait IMockChecker<TContractState> {
fn get_thing(ref self: TContractState) -> felt252;
}


#[starknet::interface]
trait IMockCheckerProxy<TContractState> {
fn get_thing_from_contract(ref self: TContractState, address: ContractAddress) -> felt252;
fn get_thing_from_contract_and_emit_event(ref self: TContractState, address: ContractAddress) -> felt252;
}

#[starknet::contract]
mod MockCheckerProxy {
use starknet::ContractAddress;
use super::IMockCheckerDispatcherTrait;
use super::IMockCheckerDispatcher;

#[storage]
struct Storage {}

#[event]
#[derive(Drop, starknet::Event)]
enum Event {
ThingEmitted: ThingEmitted
}

#[derive(Drop, starknet::Event)]
struct ThingEmitted {
thing: felt252
}

#[external(v0)]
impl IMockCheckerProxy of super::IMockCheckerProxy<ContractState> {
fn get_thing_from_contract(ref self: ContractState, address: ContractAddress) -> felt252 {
let dispatcher = IMockCheckerDispatcher { contract_address: address };
dispatcher.get_thing()
}

fn get_thing_from_contract_and_emit_event(ref self: ContractState, address: ContractAddress) -> felt252 {
let dispatcher = IMockCheckerDispatcher { contract_address: address };
let thing = dispatcher.get_thing();
self.emit(Event::ThingEmitted(ThingEmitted { thing }));
thing
}
}
}
Loading

0 comments on commit 1981e86

Please sign in to comment.