Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[KGA-26] [KGA-17] feat: add checks on Starknet return values for DualVMToken #1616

Open
wants to merge 3 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions cairo/token/src/lib.cairo
Original file line number Diff line number Diff line change
@@ -1 +1,2 @@
mod starknet_token;
mod non_standard_starknet_token;
131 changes: 131 additions & 0 deletions cairo/token/src/non_standard_starknet_token.cairo
Original file line number Diff line number Diff line change
@@ -0,0 +1,131 @@
//! A non-standard implementation of the ERC20 Token on Starknet.
//! Used for testing purposes with the DualVMToken.
//! Applied the following changes:
//! - `transfer` and `transfer_from` always return false
//! - `approve` always returns false
//! - `name`, `symbol` return a felt instead of a ByteArray

#[starknet::interface]
trait IERC20FeltMetadata<TState> {
fn name(self: @TState) -> felt252;
fn symbol(self: @TState) -> felt252;
fn decimals(self: @TState) -> u8;
}


#[starknet::contract]
mod NonStandardStarknetToken {
use openzeppelin::token::erc20::ERC20Component;
use openzeppelin::token::erc20::interface::{IERC20};
use super::IERC20FeltMetadata;
use starknet::ContractAddress;

component!(path: ERC20Component, storage: erc20, event: ERC20Event);

impl ERC20InternalImpl = ERC20Component::InternalImpl<ContractState>;

#[storage]
struct Storage {
#[substorage(v0)]
erc20: ERC20Component::Storage,
decimals: u8,
name: felt252,
symbol: felt252,
}

#[event]
#[derive(Drop, starknet::Event)]
enum Event {
#[flat]
ERC20Event: ERC20Component::Event
}

#[constructor]
fn constructor(
ref self: ContractState,
name: felt252,
symbol: felt252,
decimals: u8,
initial_supply: u256,
recipient: ContractAddress
) {
self._set_decimals(decimals);

// ERC20 initialization
self.name.write(name);
self.symbol.write(symbol);
self.erc20._mint(recipient, initial_supply);
}

#[external(v0)]
fn mint(ref self: ContractState, to: ContractAddress, amount: u256) {
self.erc20._mint(to, amount);
}

#[abi(embed_v0)]
impl ERC20MetadataImpl of IERC20FeltMetadata<ContractState> {
fn name(self: @ContractState) -> felt252 {
self.name.read()
}

fn symbol(self: @ContractState) -> felt252 {
self.symbol.read()
}

fn decimals(self: @ContractState) -> u8 {
self.decimals.read()
}
}

#[abi(embed_v0)]
impl ERC20 of IERC20<ContractState> {
/// Returns the value of tokens in existence.
fn total_supply(self: @ContractState) -> u256 {
self.erc20.ERC20_total_supply.read()
}

/// Returns the amount of tokens owned by `account`.
fn balance_of(self: @ContractState, account: ContractAddress) -> u256 {
self.erc20.ERC20_balances.read(account)
}

/// Returns the remaining number of tokens that `spender` is
/// allowed to spend on behalf of `owner` through `transfer_from`.
/// This is zero by default.
/// This value changes when `approve` or `transfer_from` are called.
fn allowance(
self: @ContractState, owner: ContractAddress, spender: ContractAddress
) -> u256 {
self.erc20.ERC20_allowances.read((owner, spender))
}


/// Modified to always return false
fn transfer(ref self: ContractState, recipient: ContractAddress, amount: u256) -> bool {
false
}


/// Modified to always return false
fn transfer_from(
ref self: ContractState,
sender: ContractAddress,
recipient: ContractAddress,
amount: u256
) -> bool {
false
}

/// Modified to always return false
fn approve(ref self: ContractState, spender: ContractAddress, amount: u256) -> bool {
false
}
}

#[generate_trait]
impl InternalImpl of InternalTrait {
fn _set_decimals(ref self: ContractState, decimals: u8) {
self.decimals.write(decimals);
}
}
}
9 changes: 8 additions & 1 deletion cairo/token/src/starknet_token.cairo
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,14 @@ mod StarknetToken {
}

#[constructor]
fn constructor(ref self: ContractState, name: ByteArray, symbol: ByteArray, decimals: u8, initial_supply: u256, recipient: ContractAddress) {
fn constructor(
ref self: ContractState,
name: ByteArray,
symbol: ByteArray,
decimals: u8,
initial_supply: u256,
recipient: ContractAddress
) {
self._set_decimals(decimals);

// ERC20 initialization
Expand Down
34 changes: 31 additions & 3 deletions kakarot_scripts/compile_kakarot.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,22 @@
# %% Imports
import logging
import multiprocessing as mp
import re
from datetime import datetime

from kakarot_scripts.constants import COMPILED_CONTRACTS, DECLARED_CONTRACTS, NETWORK
from kakarot_scripts.constants import (
CAIRO_DIR,
COMPILED_CONTRACTS,
CONTRACTS,
DECLARED_CONTRACTS,
NETWORK,
)
from kakarot_scripts.utils.starknet import (
compile_contract,
compile_cairo_zero_contract,
compile_scarb_package,
compute_deployed_class_hash,
dump_class_hashes,
locate_scarb_root,
)

mp.set_start_method("fork")
Expand All @@ -22,8 +31,27 @@ def main():
# %% Compile
logger.info(f"ℹ️ Compiling contracts for network {NETWORK['name']}")
initial_time = datetime.now()

# Split contracts into Cairo 0 and Cairo 1 to avoid
# re-compiling the same package multiple times.
cairo0_contracts = []
cairo1_packages = set()
obatirou marked this conversation as resolved.
Show resolved Hide resolved

for contract in COMPILED_CONTRACTS:
contract_path = CONTRACTS.get(contract["contract_name"]) or CONTRACTS.get(
re.sub("(?!^)([A-Z]+)", r"_\1", contract["contract_name"]).lower()
)
if contract_path.is_relative_to(CAIRO_DIR):
cairo1_packages.add(locate_scarb_root(contract_path))
else:
cairo0_contracts.append(contract)

with mp.Pool() as pool:
pool.map(compile_contract, COMPILED_CONTRACTS)
cairo0_task = pool.map_async(compile_cairo_zero_contract, cairo0_contracts)
cairo1_task = pool.map_async(compile_scarb_package, cairo1_packages)

cairo0_task.wait()
cairo1_task.wait()
logger.info("ℹ️ Computing deployed class hashes")
with mp.Pool() as pool:
class_hashes = pool.map(compute_deployed_class_hash, DECLARED_CONTRACTS)
Expand Down
2 changes: 2 additions & 0 deletions kakarot_scripts/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -224,6 +224,7 @@ class ChainId(IntEnum):
{"contract_name": "OpenzeppelinAccount", "is_account_contract": True},
{"contract_name": "replace_class", "is_account_contract": False},
{"contract_name": "StarknetToken", "is_account_contract": False},
{"contract_name": "NonStandardStarknetToken", "is_account_contract": False},
{"contract_name": "uninitialized_account_fixture", "is_account_contract": False},
{"contract_name": "uninitialized_account", "is_account_contract": False},
{"contract_name": "UniversalLibraryCaller", "is_account_contract": False},
Expand All @@ -244,6 +245,7 @@ class ChainId(IntEnum):
"OpenzeppelinAccount",
"replace_class",
"StarknetToken",
"NonStandardStarknetToken",
"uninitialized_account_fixture",
"uninitialized_account",
"UniversalLibraryCaller",
Expand Down
91 changes: 55 additions & 36 deletions kakarot_scripts/utils/starknet.py
Original file line number Diff line number Diff line change
Expand Up @@ -303,17 +303,20 @@ def get_artifact(contract_name):
# Cairo 1 artifacts
artifacts = list(BUILD_DIR_SSJ.glob(f"**/*{contract_name}.*.json")) or [
artifact
for artifact in list(CAIRO_DIR.glob(f"**/*{contract_name}.*.json"))
for artifact in list(CAIRO_DIR.glob(f"**/*_{contract_name}.*.json"))
if "test" not in str(artifact)
]
if artifacts:
sierra, casm = (
artifacts
if "sierra.json" in artifacts[0].name
or ".contract_class.json" in artifacts[0].name
else artifacts[::-1]
)
return Artifact(sierra=sierra, casm=casm)
try:
sierra, casm = (
artifacts
if "sierra.json" in artifacts[0].name
or ".contract_class.json" in artifacts[0].name
else artifacts[::-1]
)
return Artifact(sierra=sierra, casm=casm)
except Exception as e:
raise FileNotFoundError(f"No artifact found for {contract_name}") from e

raise FileNotFoundError(f"No artifact found for {contract_name}")

Expand All @@ -336,40 +339,56 @@ def get_tx_url(tx_hash: int) -> str:
return f"{NETWORK['explorer_url']}/tx/0x{tx_hash:064x}"


def compile_contract(contract):
def locate_scarb_root(contract_path):
current_dir = contract_path.parent
while current_dir != current_dir.parent:
scarb_toml = current_dir / "Scarb.toml"
if scarb_toml.exists():
return current_dir
current_dir = current_dir.parent
return None


def compile_scarb_package(package_path):
logger.info(f"ℹ️ Compiling package {package_path}")
start = datetime.now()
output = subprocess.run(
"scarb build", shell=True, cwd=package_path, capture_output=True
)
if output.returncode != 0:
raise RuntimeError(
f"❌ {package_path} raised:\n{output.stderr}.\nOutput:\n{output.stdout}"
)

elapsed = datetime.now() - start
logger.info(f"✅ {package_path} compiled in {elapsed.total_seconds():.2f}s")


def compile_cairo_zero_contract(contract):
logger.info(f"⏳ Compiling {contract['contract_name']}")
start = datetime.now()
contract_path = CONTRACTS.get(contract["contract_name"]) or CONTRACTS.get(
re.sub("(?!^)([A-Z]+)", r"_\1", contract["contract_name"]).lower()
)

if contract_path.is_relative_to(CAIRO_DIR):
output = subprocess.run(
"scarb build", shell=True, cwd=contract_path.parent, capture_output=True
)
else:
output = subprocess.run(
[
"starknet-compile-deprecated",
contract_path,
"--output",
BUILD_DIR / f"{contract['contract_name']}.json",
"--cairo_path",
str(CAIRO_ZERO_DIR),
*(
["--no_debug_info"]
if NETWORK["type"] is not NetworkType.DEV
else []
),
*(["--account_contract"] if contract["is_account_contract"] else []),
*(
["--disable_hint_validation"]
if NETWORK["type"] is NetworkType.DEV
else []
),
],
capture_output=True,
)
output = subprocess.run(
[
"starknet-compile-deprecated",
contract_path,
"--output",
BUILD_DIR / f"{contract['contract_name']}.json",
"--cairo_path",
str(CAIRO_ZERO_DIR),
*(["--no_debug_info"] if NETWORK["type"] is not NetworkType.DEV else []),
*(["--account_contract"] if contract["is_account_contract"] else []),
*(
["--disable_hint_validation"]
if NETWORK["type"] is NetworkType.DEV
else []
),
],
capture_output=True,
)

if output.returncode != 0:
raise RuntimeError(
Expand Down
17 changes: 14 additions & 3 deletions solidity_contracts/src/CairoPrecompiles/DualVmToken.sol
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,10 @@ contract DualVmToken is NoDelegateCall {

/// @dev Emitted when an invalid starknet address is used
error InvalidStarknetAddress();
/// @dev Emitted when the return value of a starknet transfer is `false`.
error TransferFailed();
/// @dev Emitted when the return value of a starknet approval is `false`.
error ApprovalFailed();

/*//////////////////////////////////////////////////////////////
METADATA ACCESS
Expand Down Expand Up @@ -244,7 +248,10 @@ contract DualVmToken is NoDelegateCall {
approveCallData[1] = uint256(amountLow);
approveCallData[2] = uint256(amountHigh);

starknetToken.delegatecallCairo("approve", approveCallData);
bool success = abi.decode(starknetToken.delegatecallCairo("approve", approveCallData), (bool));
if (!success) {
revert ApprovalFailed();
}
}

/// @dev Transfer tokens to an evm account
Expand Down Expand Up @@ -285,7 +292,10 @@ contract DualVmToken is NoDelegateCall {
transferCallData[1] = uint256(amountLow);
transferCallData[2] = uint256(amountHigh);

starknetToken.delegatecallCairo("transfer", transferCallData);
bool success = abi.decode(starknetToken.delegatecallCairo("transfer", transferCallData), (bool));
if (!success) {
revert TransferFailed();
}
}

/// @dev Transfer tokens from one evm address to another
Expand Down Expand Up @@ -369,6 +379,7 @@ contract DualVmToken is NoDelegateCall {
transferFromCallData[2] = uint256(amountLow);
transferFromCallData[3] = uint256(amountHigh);

starknetToken.delegatecallCairo("transfer_from", transferFromCallData);
bool success = abi.decode(starknetToken.delegatecallCairo("transfer_from", transferFromCallData), (bool));
require(success, "Transfer failed");
}
}
Loading
Loading