diff --git a/kakarot_scripts/utils/kakarot.py b/kakarot_scripts/utils/kakarot.py index ef75ad2d0..26eb5ba6a 100644 --- a/kakarot_scripts/utils/kakarot.py +++ b/kakarot_scripts/utils/kakarot.py @@ -44,8 +44,7 @@ ChainId, ) from kakarot_scripts.data.pre_eip155_txs import PRE_EIP155_TX -from kakarot_scripts.utils.relayers import RelayerPool -from kakarot_scripts.utils.starknet import _max_fee +from kakarot_scripts.utils.starknet import RelayerPool, _max_fee from kakarot_scripts.utils.starknet import call from kakarot_scripts.utils.starknet import call as _call_starknet from kakarot_scripts.utils.starknet import fund_address as _fund_starknet_address diff --git a/kakarot_scripts/utils/relayers.py b/kakarot_scripts/utils/relayers.py deleted file mode 100644 index aa9695477..000000000 --- a/kakarot_scripts/utils/relayers.py +++ /dev/null @@ -1,84 +0,0 @@ -import logging - -from async_lru import alru_cache -from starknet_py.net.account.account import Account - -from kakarot_scripts.constants import NETWORK -from kakarot_scripts.utils.starknet import ( - deploy_starknet_account, - get_eth_contract, - get_starknet_account, - invoke, -) - -logging.basicConfig(level=logging.INFO, format="%(levelname)s: %(message)s") -logger = logging.getLogger(__name__) -logger.setLevel(logging.INFO) - - -class RelayerPool: - _cached_relayers = None - - def __init__(self, accounts): - self.relayer_accounts = accounts - self.index = 0 - - @classmethod - @alru_cache - async def create(cls, n, **kwargs): - logger.info(f"ℹ️ Creating {n} relayer accounts") - accounts = [] - for i in range(n): - receipt = await deploy_starknet_account( - salt=i + int(NETWORK["account_address"], 16), **kwargs - ) - account = await get_starknet_account(address=receipt["address"]) - accounts.append(account) - logger.info(f"✅ Created {n} relayer accounts") - return cls(accounts) - - def __next__(self) -> Account: - relayer = self.relayer_accounts[self.index] - self.index = (self.index + 1) % len(self.relayer_accounts) - return relayer - - @classmethod - @alru_cache - async def default(cls, **kwargs): - return await cls.create(NETWORK.get("relayers", 20), **kwargs) - - @classmethod - @alru_cache - async def get(cls, salt: int): - if cls._cached_relayers is None: - cls._cached_relayers = await cls.default() - - return cls._cached_relayers.relayer_accounts[ - salt % len(cls._cached_relayers.relayer_accounts) - ] - - async def balances(self): - eth_contract = await get_eth_contract() - return [ - ( - f"0x{relayer.address:064x}", - f"{(await eth_contract.functions['balanceOf'].call(relayer.address)).balance / 1e18:.2f} ETH", - ) - for relayer in self.relayer_accounts - ] - - async def withdraw_all(self, to: int = int(NETWORK["account_address"], 16)): - eth_contract = await get_eth_contract() - for relayer in self.relayer_accounts: - balance = ( - await eth_contract.functions["balanceOf"].call(relayer.address) - ).balance - if balance > 0: - await invoke( - "ERC20", - "transfer", - to, - balance, - account=relayer, - address=eth_contract.address, - ) diff --git a/kakarot_scripts/utils/starknet.py b/kakarot_scripts/utils/starknet.py index fa3cb4f86..a27a3f02c 100644 --- a/kakarot_scripts/utils/starknet.py +++ b/kakarot_scripts/utils/starknet.py @@ -71,7 +71,6 @@ _logs = defaultdict(list) _lazy_execute = defaultdict(bool) _multisig_account = defaultdict(bool) -_single_sig_account = None # Dict to store selector to name mapping because argent api requires the name but calls have selector _selector_to_name = {get_selector_from_name("deployContract"): "deployContract"} @@ -96,8 +95,6 @@ async def get_starknet_account( address=None, private_key=None, ) -> Account: - global _single_sig_account - address = address or NETWORK["account_address"] if address is None: raise ValueError( @@ -152,17 +149,7 @@ async def get_starknet_account( ) if len(public_keys) > 1: register_multisig_account(address) - logger.info( - "ℹ️ Account is a multisig: " - "deploying a regular account for declarations" - "using the same private key and the multisig address as salt" - ) - receipt = await deploy_starknet_account( - private_key=private_key, salt=address - ) - _single_sig_account = await get_starknet_account( - address=receipt["address"], private_key=private_key - ) + logger.info("ℹ️ Account is a multisig") else: logger.warning( @@ -484,7 +471,7 @@ async def declare(contract_name): account = await get_starknet_account() if _multisig_account[account.address]: - account = _single_sig_account + account = await RelayerPool.get(account.address) if artifact.sierra is not None: casm_compiled_contract = artifact.casm.read_text() @@ -795,3 +782,71 @@ async def get_class_hash_at(address): return await RPC_CLIENT.get_class_hash_at(address) except Exception: return None + + +class RelayerPool: + _cached_relayers = None + + def __init__(self, accounts): + self.relayer_accounts = accounts + self.index = 0 + + @classmethod + @alru_cache + async def create(cls, n, **kwargs): + logger.info(f"ℹ️ Creating {n} relayer accounts") + accounts = [] + for i in range(n): + receipt = await deploy_starknet_account( + salt=i + int(NETWORK["account_address"], 16), **kwargs + ) + account = await get_starknet_account(address=receipt["address"]) + accounts.append(account) + logger.info(f"✅ Created {n} relayer accounts") + return cls(accounts) + + def __next__(self) -> Account: + relayer = self.relayer_accounts[self.index] + self.index = (self.index + 1) % len(self.relayer_accounts) + return relayer + + @classmethod + @alru_cache + async def default(cls, **kwargs): + return await cls.create(NETWORK.get("relayers", 20), **kwargs) + + @classmethod + @alru_cache + async def get(cls, salt: int): + if cls._cached_relayers is None: + cls._cached_relayers = await cls.default() + + return cls._cached_relayers.relayer_accounts[ + salt % len(cls._cached_relayers.relayer_accounts) + ] + + async def balances(self): + eth_contract = await get_eth_contract() + return [ + ( + f"0x{relayer.address:064x}", + f"{(await eth_contract.functions['balanceOf'].call(relayer.address)).balance / 1e18:.2f} ETH", + ) + for relayer in self.relayer_accounts + ] + + async def withdraw_all(self, to: int = int(NETWORK["account_address"], 16)): + eth_contract = await get_eth_contract() + for relayer in self.relayer_accounts: + balance = ( + await eth_contract.functions["balanceOf"].call(relayer.address) + ).balance + if balance > 0: + await invoke( + "ERC20", + "transfer", + to, + balance, + account=relayer, + address=eth_contract.address, + )