From 146d62d2d07df7307d00b126a2d13a39dfff40c9 Mon Sep 17 00:00:00 2001 From: enitrat Date: Mon, 18 Nov 2024 14:45:07 +0700 Subject: [PATCH 1/3] Add checks on Starknet return values for DualVMToken --- cairo/token/src/lib.cairo | 1 + .../src/non_standard_starknet_token.cairo | 128 ++++++++++++++++++ cairo/token/src/starknet_token.cairo | 1 + kakarot_scripts/compile_kakarot.py | 34 ++++- kakarot_scripts/constants.py | 2 + kakarot_scripts/utils/starknet.py | 91 ++++++++----- .../src/CairoPrecompiles/DualVmToken.sol | 9 +- .../CairoPrecompiles/test_dual_vm_token.py | 113 ++++++++++++++++ 8 files changed, 337 insertions(+), 42 deletions(-) create mode 100644 cairo/token/src/non_standard_starknet_token.cairo diff --git a/cairo/token/src/lib.cairo b/cairo/token/src/lib.cairo index 9b8f06e6a..8b8d86cbe 100644 --- a/cairo/token/src/lib.cairo +++ b/cairo/token/src/lib.cairo @@ -1 +1,2 @@ mod starknet_token; +mod non_standard_starknet_token; diff --git a/cairo/token/src/non_standard_starknet_token.cairo b/cairo/token/src/non_standard_starknet_token.cairo new file mode 100644 index 000000000..ed6bb4777 --- /dev/null +++ b/cairo/token/src/non_standard_starknet_token.cairo @@ -0,0 +1,128 @@ +//! A non-standard implementation of the ERC20 Token on Starknet. +//! Used for testing purposes with the DualVMToken. +//! Applied the following changes: +//! - `transfer` and `transfer_from` always return false +//! - `approve` always returns false +//! - `name`, `symbol` return a felt instead of a ByteArray + +#[starknet::interface] +trait IERC20FeltMetadata { + fn name(self: @TState) -> felt252; + fn symbol(self: @TState) -> felt252; + fn decimals(self: @TState) -> u8; +} + + +#[starknet::contract] +mod NonStandardStarknetToken { + use openzeppelin::token::erc20::ERC20Component; + use openzeppelin::token::erc20::interface::{IERC20}; + use super::IERC20FeltMetadata; + use starknet::ContractAddress; + + component!(path: ERC20Component, storage: erc20, event: ERC20Event); + + impl ERC20InternalImpl = ERC20Component::InternalImpl; + + #[storage] + struct Storage { + #[substorage(v0)] + erc20: ERC20Component::Storage, + decimals: u8, + name: felt252, + symbol: felt252, + } + + #[event] + #[derive(Drop, starknet::Event)] + enum Event { + #[flat] + ERC20Event: ERC20Component::Event + } + + #[constructor] + fn constructor(ref self: ContractState, name: felt252, symbol: felt252, decimals: u8, initial_supply: u256, recipient: ContractAddress) { + self._set_decimals(decimals); + + // ERC20 initialization + self.name.write(name); + self.symbol.write(symbol); + self.erc20._mint(recipient, initial_supply); + } + + #[external(v0)] + fn mint(ref self: ContractState, to: ContractAddress, amount: u256) { + self.erc20._mint(to, amount); + } + + #[abi(embed_v0)] + impl ERC20MetadataImpl of IERC20FeltMetadata { + fn name(self: @ContractState) -> felt252 { + self.name.read() + } + + fn symbol(self: @ContractState) -> felt252 { + self.symbol.read() + } + + fn decimals(self: @ContractState) -> u8 { + self.decimals.read() + } + } + + #[abi(embed_v0)] + impl ERC20 of IERC20 { + /// Returns the value of tokens in existence. + fn total_supply(self: @ContractState) -> u256 { + self.erc20.ERC20_total_supply.read() + } + + /// Returns the amount of tokens owned by `account`. + fn balance_of(self: @ContractState, account: ContractAddress) -> u256 { + self.erc20.ERC20_balances.read(account) + } + + /// Returns the remaining number of tokens that `spender` is + /// allowed to spend on behalf of `owner` through `transfer_from`. + /// This is zero by default. + /// This value changes when `approve` or `transfer_from` are called. + fn allowance( + self: @ContractState, owner: ContractAddress, spender: ContractAddress + ) -> u256 { + self.erc20.ERC20_allowances.read((owner, spender)) + } + + + /// Modified to always return false + fn transfer( + ref self: ContractState, recipient: ContractAddress, amount: u256 + ) -> bool { + false + } + + + /// Modified to always return false + fn transfer_from( + ref self: ContractState, + sender: ContractAddress, + recipient: ContractAddress, + amount: u256 + ) -> bool { + false + } + + /// Modified to always return false + fn approve( + ref self: ContractState, spender: ContractAddress, amount: u256 + ) -> bool { + false + } + } + + #[generate_trait] + impl InternalImpl of InternalTrait { + fn _set_decimals(ref self: ContractState, decimals: u8) { + self.decimals.write(decimals); + } + } +} diff --git a/cairo/token/src/starknet_token.cairo b/cairo/token/src/starknet_token.cairo index ddd4ee728..cb5d6dea7 100644 --- a/cairo/token/src/starknet_token.cairo +++ b/cairo/token/src/starknet_token.cairo @@ -61,4 +61,5 @@ mod StarknetToken { self.decimals.write(decimals); } } + } diff --git a/kakarot_scripts/compile_kakarot.py b/kakarot_scripts/compile_kakarot.py index 64b28e0ad..d89ca2cee 100644 --- a/kakarot_scripts/compile_kakarot.py +++ b/kakarot_scripts/compile_kakarot.py @@ -1,13 +1,22 @@ # %% Imports import logging import multiprocessing as mp +import re from datetime import datetime -from kakarot_scripts.constants import COMPILED_CONTRACTS, DECLARED_CONTRACTS, NETWORK +from kakarot_scripts.constants import ( + CAIRO_DIR, + COMPILED_CONTRACTS, + CONTRACTS, + DECLARED_CONTRACTS, + NETWORK, +) from kakarot_scripts.utils.starknet import ( - compile_contract, + compile_cairo_zero_contract, + compile_scarb_package, compute_deployed_class_hash, dump_class_hashes, + locate_scarb_root, ) mp.set_start_method("fork") @@ -22,8 +31,27 @@ def main(): # %% Compile logger.info(f"ℹ️ Compiling contracts for network {NETWORK['name']}") initial_time = datetime.now() + + # Split contracts into Cairo 0 and Cairo 1 to avoid + # re-compiling the same package multiple times. + cairo0_contracts = [] + cairo1_packages = set() + + for contract in COMPILED_CONTRACTS: + contract_path = CONTRACTS.get(contract["contract_name"]) or CONTRACTS.get( + re.sub("(?!^)([A-Z]+)", r"_\1", contract["contract_name"]).lower() + ) + if contract_path.is_relative_to(CAIRO_DIR): + cairo1_packages.add(locate_scarb_root(contract_path)) + else: + cairo0_contracts.append(contract) + with mp.Pool() as pool: - pool.map(compile_contract, COMPILED_CONTRACTS) + cairo0_task = pool.map_async(compile_cairo_zero_contract, cairo0_contracts) + cairo1_task = pool.map_async(compile_scarb_package, cairo1_packages) + + cairo0_task.wait() + cairo1_task.wait() logger.info("ℹ️ Computing deployed class hashes") with mp.Pool() as pool: class_hashes = pool.map(compute_deployed_class_hash, DECLARED_CONTRACTS) diff --git a/kakarot_scripts/constants.py b/kakarot_scripts/constants.py index c23d55268..8640bb6a5 100644 --- a/kakarot_scripts/constants.py +++ b/kakarot_scripts/constants.py @@ -224,6 +224,7 @@ class ChainId(IntEnum): {"contract_name": "OpenzeppelinAccount", "is_account_contract": True}, {"contract_name": "replace_class", "is_account_contract": False}, {"contract_name": "StarknetToken", "is_account_contract": False}, + {"contract_name": "NonStandardStarknetToken", "is_account_contract": False}, {"contract_name": "uninitialized_account_fixture", "is_account_contract": False}, {"contract_name": "uninitialized_account", "is_account_contract": False}, {"contract_name": "UniversalLibraryCaller", "is_account_contract": False}, @@ -244,6 +245,7 @@ class ChainId(IntEnum): "OpenzeppelinAccount", "replace_class", "StarknetToken", + "NonStandardStarknetToken", "uninitialized_account_fixture", "uninitialized_account", "UniversalLibraryCaller", diff --git a/kakarot_scripts/utils/starknet.py b/kakarot_scripts/utils/starknet.py index ac99c8392..9ebd46f9c 100644 --- a/kakarot_scripts/utils/starknet.py +++ b/kakarot_scripts/utils/starknet.py @@ -303,17 +303,20 @@ def get_artifact(contract_name): # Cairo 1 artifacts artifacts = list(BUILD_DIR_SSJ.glob(f"**/*{contract_name}.*.json")) or [ artifact - for artifact in list(CAIRO_DIR.glob(f"**/*{contract_name}.*.json")) + for artifact in list(CAIRO_DIR.glob(f"**/*_{contract_name}.*.json")) if "test" not in str(artifact) ] if artifacts: - sierra, casm = ( - artifacts - if "sierra.json" in artifacts[0].name - or ".contract_class.json" in artifacts[0].name - else artifacts[::-1] - ) - return Artifact(sierra=sierra, casm=casm) + try: + sierra, casm = ( + artifacts + if "sierra.json" in artifacts[0].name + or ".contract_class.json" in artifacts[0].name + else artifacts[::-1] + ) + return Artifact(sierra=sierra, casm=casm) + except Exception as e: + raise FileNotFoundError(f"No artifact found for {contract_name}") from e raise FileNotFoundError(f"No artifact found for {contract_name}") @@ -336,40 +339,56 @@ def get_tx_url(tx_hash: int) -> str: return f"{NETWORK['explorer_url']}/tx/0x{tx_hash:064x}" -def compile_contract(contract): +def locate_scarb_root(contract_path): + current_dir = contract_path.parent + while current_dir != current_dir.parent: + scarb_toml = current_dir / "Scarb.toml" + if scarb_toml.exists(): + return current_dir + current_dir = current_dir.parent + return None + + +def compile_scarb_package(package_path): + logger.info(f"ℹ️ Compiling package {package_path}") + start = datetime.now() + output = subprocess.run( + "scarb build", shell=True, cwd=package_path, capture_output=True + ) + if output.returncode != 0: + raise RuntimeError( + f"❌ {package_path} raised:\n{output.stderr}.\nOutput:\n{output.stdout}" + ) + + elapsed = datetime.now() - start + logger.info(f"✅ {package_path} compiled in {elapsed.total_seconds():.2f}s") + + +def compile_cairo_zero_contract(contract): logger.info(f"⏳ Compiling {contract['contract_name']}") start = datetime.now() contract_path = CONTRACTS.get(contract["contract_name"]) or CONTRACTS.get( re.sub("(?!^)([A-Z]+)", r"_\1", contract["contract_name"]).lower() ) - if contract_path.is_relative_to(CAIRO_DIR): - output = subprocess.run( - "scarb build", shell=True, cwd=contract_path.parent, capture_output=True - ) - else: - output = subprocess.run( - [ - "starknet-compile-deprecated", - contract_path, - "--output", - BUILD_DIR / f"{contract['contract_name']}.json", - "--cairo_path", - str(CAIRO_ZERO_DIR), - *( - ["--no_debug_info"] - if NETWORK["type"] is not NetworkType.DEV - else [] - ), - *(["--account_contract"] if contract["is_account_contract"] else []), - *( - ["--disable_hint_validation"] - if NETWORK["type"] is NetworkType.DEV - else [] - ), - ], - capture_output=True, - ) + output = subprocess.run( + [ + "starknet-compile-deprecated", + contract_path, + "--output", + BUILD_DIR / f"{contract['contract_name']}.json", + "--cairo_path", + str(CAIRO_ZERO_DIR), + *(["--no_debug_info"] if NETWORK["type"] is not NetworkType.DEV else []), + *(["--account_contract"] if contract["is_account_contract"] else []), + *( + ["--disable_hint_validation"] + if NETWORK["type"] is NetworkType.DEV + else [] + ), + ], + capture_output=True, + ) if output.returncode != 0: raise RuntimeError( diff --git a/solidity_contracts/src/CairoPrecompiles/DualVmToken.sol b/solidity_contracts/src/CairoPrecompiles/DualVmToken.sol index aa9a63a02..e37db2c5a 100644 --- a/solidity_contracts/src/CairoPrecompiles/DualVmToken.sol +++ b/solidity_contracts/src/CairoPrecompiles/DualVmToken.sol @@ -244,7 +244,8 @@ contract DualVmToken is NoDelegateCall { approveCallData[1] = uint256(amountLow); approveCallData[2] = uint256(amountHigh); - starknetToken.delegatecallCairo("approve", approveCallData); + bool success = abi.decode(starknetToken.delegatecallCairo("approve", approveCallData), (bool)); + require(success, "Approval failed"); } /// @dev Transfer tokens to an evm account @@ -285,7 +286,8 @@ contract DualVmToken is NoDelegateCall { transferCallData[1] = uint256(amountLow); transferCallData[2] = uint256(amountHigh); - starknetToken.delegatecallCairo("transfer", transferCallData); + bool success = abi.decode(starknetToken.delegatecallCairo("transfer", transferCallData), (bool)); + require(success, "Transfer failed"); } /// @dev Transfer tokens from one evm address to another @@ -369,6 +371,7 @@ contract DualVmToken is NoDelegateCall { transferFromCallData[2] = uint256(amountLow); transferFromCallData[3] = uint256(amountHigh); - starknetToken.delegatecallCairo("transfer_from", transferFromCallData); + bool success = abi.decode(starknetToken.delegatecallCairo("transfer_from", transferFromCallData), (bool)); + require(success, "Transfer failed"); } } diff --git a/tests/end_to_end/CairoPrecompiles/test_dual_vm_token.py b/tests/end_to_end/CairoPrecompiles/test_dual_vm_token.py index 936a4fe58..c2126f08f 100644 --- a/tests/end_to_end/CairoPrecompiles/test_dual_vm_token.py +++ b/tests/end_to_end/CairoPrecompiles/test_dual_vm_token.py @@ -41,6 +41,40 @@ async def dual_vm_token(kakarot, starknet_token, owner): return dual_vm_token +@pytest_asyncio.fixture(scope="package") +async def not_std_starknet_token(owner): + # A Non-Standard Starknet ERC20 token, that returns `false` on actions + # and has `felt252` return types for metadata. + address = await deploy_starknet( + "NonStandardStarknetToken", + "MyToken", + "MTK", + 18, + int(2**256 - 1), + owner.starknet_contract.address, + ) + return get_contract_starknet("NonStandardStarknetToken", address=address) + + +@pytest_asyncio.fixture(scope="package") +async def not_std_dual_vm_token(kakarot, not_std_starknet_token): + # A wrapper around a non-standard ERC20 to test for edge-case behaviors. + not_std_dual_vm_token = await deploy_kakarot( + "CairoPrecompiles", + "DualVmToken", + kakarot.address, + not_std_starknet_token.address, + ) + + await invoke( + "kakarot", + "set_authorized_cairo_precompile_caller", + int(not_std_dual_vm_token.address, 16), + True, + ) + return not_std_dual_vm_token + + @pytest.mark.asyncio(scope="package") @pytest.mark.CairoPrecompiles class TestDualVmToken: @@ -646,3 +680,82 @@ async def test_should_add_liquidity_and_swap( balance_other_before + amount_dual_vm_token_desired == balance_other_after ) + + +class TestNonStandardStarknetToken: + """ + Tests for DualVMToken wrapping a non-standard Starknet token + Covers cases where: + - The wrapped token returns `false` on `transfer`, `approval`, `transfer_from`. + - The metadata is returned using `felt252` instead of `ByteArray`. + """ + + class TestMetadata: + async def test_should_return_name( + self, not_std_starknet_token, not_std_dual_vm_token + ): + (name_starknet,) = await not_std_starknet_token.functions["name"].call() + name_evm_bytes = bytes(await not_std_dual_vm_token.name(), "UTF-8") + assert name_starknet == int.from_bytes(name_evm_bytes, "big") + + async def test_should_return_symbol( + self, not_std_starknet_token, not_std_dual_vm_token + ): + (symbol_starknet,) = await not_std_starknet_token.functions["symbol"].call() + symbol_evm_bytes = bytes(await not_std_dual_vm_token.symbol(), "UTF-8") + assert symbol_starknet == int.from_bytes(symbol_evm_bytes, "big") + + async def test_should_return_decimals( + self, not_std_starknet_token, not_std_dual_vm_token + ): + (decimals_starknet,) = await not_std_starknet_token.functions[ + "decimals" + ].call() + decimals_evm = await not_std_dual_vm_token.decimals() + assert decimals_starknet == decimals_evm + + class TestActions: + """ + The Starknet token returning false, the DualVMToken wrapper will fail the `require(success)` which + will revert the Starknet TX to avoid unwanted Starknet state changes. + """ + + async def test_transfer_should_fail_evm_tx(self, not_std_dual_vm_token, other): + with cairo_error( + "EVM tx reverted, reverting SN tx because of previous calls to cairo precompiles" + ): + success = ( + await not_std_dual_vm_token.functions["transfer(address,uint256)"]( + other.address, 1 + ) + )["success"] + assert success == 0 + + async def test_approve_should_fail_evm_tx(self, not_std_dual_vm_token, other): + with cairo_error( + "EVM tx reverted, reverting SN tx because of previous calls to cairo precompiles" + ): + success = ( + await not_std_dual_vm_token.functions["approve(address,uint256)"]( + other.address, 1 + ) + )["success"] + assert success == 0 + + async def test_transfer_from_should_fail_evm_tx( + self, not_std_dual_vm_token, owner, other + ): + with cairo_error( + "EVM tx reverted, reverting SN tx because of previous calls to cairo precompiles" + ): + success = ( + await not_std_dual_vm_token.functions[ + "transferFrom(address,address,uint256)" + ]( + owner.address, + other.address, + 1, + caller_eoa=other.starknet_contract, + ) + )["success"] + assert success == 0 From f74cfb8bc5ebd70f7c3c081332f45d169230e068 Mon Sep 17 00:00:00 2001 From: enitrat Date: Mon, 18 Nov 2024 18:11:21 +0800 Subject: [PATCH 2/3] fmt --- .../src/non_standard_starknet_token.cairo | 95 ++++++++++--------- cairo/token/src/starknet_token.cairo | 10 +- 2 files changed, 57 insertions(+), 48 deletions(-) diff --git a/cairo/token/src/non_standard_starknet_token.cairo b/cairo/token/src/non_standard_starknet_token.cairo index ed6bb4777..85715c739 100644 --- a/cairo/token/src/non_standard_starknet_token.cairo +++ b/cairo/token/src/non_standard_starknet_token.cairo @@ -41,7 +41,14 @@ mod NonStandardStarknetToken { } #[constructor] - fn constructor(ref self: ContractState, name: felt252, symbol: felt252, decimals: u8, initial_supply: u256, recipient: ContractAddress) { + fn constructor( + ref self: ContractState, + name: felt252, + symbol: felt252, + decimals: u8, + initial_supply: u256, + recipient: ContractAddress + ) { self._set_decimals(decimals); // ERC20 initialization @@ -72,51 +79,47 @@ mod NonStandardStarknetToken { #[abi(embed_v0)] impl ERC20 of IERC20 { - /// Returns the value of tokens in existence. - fn total_supply(self: @ContractState) -> u256 { - self.erc20.ERC20_total_supply.read() - } - - /// Returns the amount of tokens owned by `account`. - fn balance_of(self: @ContractState, account: ContractAddress) -> u256 { - self.erc20.ERC20_balances.read(account) - } - - /// Returns the remaining number of tokens that `spender` is - /// allowed to spend on behalf of `owner` through `transfer_from`. - /// This is zero by default. - /// This value changes when `approve` or `transfer_from` are called. - fn allowance( - self: @ContractState, owner: ContractAddress, spender: ContractAddress - ) -> u256 { - self.erc20.ERC20_allowances.read((owner, spender)) - } - - - /// Modified to always return false - fn transfer( - ref self: ContractState, recipient: ContractAddress, amount: u256 - ) -> bool { - false - } - - - /// Modified to always return false - fn transfer_from( - ref self: ContractState, - sender: ContractAddress, - recipient: ContractAddress, - amount: u256 - ) -> bool { - false - } - - /// Modified to always return false - fn approve( - ref self: ContractState, spender: ContractAddress, amount: u256 - ) -> bool { - false - } + /// Returns the value of tokens in existence. + fn total_supply(self: @ContractState) -> u256 { + self.erc20.ERC20_total_supply.read() + } + + /// Returns the amount of tokens owned by `account`. + fn balance_of(self: @ContractState, account: ContractAddress) -> u256 { + self.erc20.ERC20_balances.read(account) + } + + /// Returns the remaining number of tokens that `spender` is + /// allowed to spend on behalf of `owner` through `transfer_from`. + /// This is zero by default. + /// This value changes when `approve` or `transfer_from` are called. + fn allowance( + self: @ContractState, owner: ContractAddress, spender: ContractAddress + ) -> u256 { + self.erc20.ERC20_allowances.read((owner, spender)) + } + + + /// Modified to always return false + fn transfer(ref self: ContractState, recipient: ContractAddress, amount: u256) -> bool { + false + } + + + /// Modified to always return false + fn transfer_from( + ref self: ContractState, + sender: ContractAddress, + recipient: ContractAddress, + amount: u256 + ) -> bool { + false + } + + /// Modified to always return false + fn approve(ref self: ContractState, spender: ContractAddress, amount: u256) -> bool { + false + } } #[generate_trait] diff --git a/cairo/token/src/starknet_token.cairo b/cairo/token/src/starknet_token.cairo index cb5d6dea7..c06c7c1d5 100644 --- a/cairo/token/src/starknet_token.cairo +++ b/cairo/token/src/starknet_token.cairo @@ -27,7 +27,14 @@ mod StarknetToken { } #[constructor] - fn constructor(ref self: ContractState, name: ByteArray, symbol: ByteArray, decimals: u8, initial_supply: u256, recipient: ContractAddress) { + fn constructor( + ref self: ContractState, + name: ByteArray, + symbol: ByteArray, + decimals: u8, + initial_supply: u256, + recipient: ContractAddress + ) { self._set_decimals(decimals); // ERC20 initialization @@ -61,5 +68,4 @@ mod StarknetToken { self.decimals.write(decimals); } } - } From 7bb2860fa206ce30e1a06730471a61e584c253e3 Mon Sep 17 00:00:00 2001 From: enitrat Date: Mon, 18 Nov 2024 21:42:31 +0800 Subject: [PATCH 3/3] use solidity errors --- .../src/CairoPrecompiles/DualVmToken.sol | 12 ++++++++++-- 1 file changed, 10 insertions(+), 2 deletions(-) diff --git a/solidity_contracts/src/CairoPrecompiles/DualVmToken.sol b/solidity_contracts/src/CairoPrecompiles/DualVmToken.sol index e37db2c5a..0261fa400 100644 --- a/solidity_contracts/src/CairoPrecompiles/DualVmToken.sol +++ b/solidity_contracts/src/CairoPrecompiles/DualVmToken.sol @@ -59,6 +59,10 @@ contract DualVmToken is NoDelegateCall { /// @dev Emitted when an invalid starknet address is used error InvalidStarknetAddress(); + /// @dev Emitted when the return value of a starknet transfer is `false`. + error TransferFailed(); + /// @dev Emitted when the return value of a starknet approval is `false`. + error ApprovalFailed(); /*////////////////////////////////////////////////////////////// METADATA ACCESS @@ -245,7 +249,9 @@ contract DualVmToken is NoDelegateCall { approveCallData[2] = uint256(amountHigh); bool success = abi.decode(starknetToken.delegatecallCairo("approve", approveCallData), (bool)); - require(success, "Approval failed"); + if (!success) { + revert ApprovalFailed(); + } } /// @dev Transfer tokens to an evm account @@ -287,7 +293,9 @@ contract DualVmToken is NoDelegateCall { transferCallData[2] = uint256(amountHigh); bool success = abi.decode(starknetToken.delegatecallCairo("transfer", transferCallData), (bool)); - require(success, "Transfer failed"); + if (!success) { + revert TransferFailed(); + } } /// @dev Transfer tokens from one evm address to another