Add get missing metadata contracts task #58

Merged · 7 commits · Jan 14, 2025
3 changes: 3 additions & 0 deletions app/config.py
@@ -37,6 +37,9 @@ class Settings(BaseSettings):
ETHERSCAN_MAX_REQUESTS: int = 1
BLOCKSCOUT_MAX_REQUESTS: int = 1
SOURCIFY_MAX_REQUESTS: int = 100
CONTRACT_MAX_DOWNLOAD_RETRIES: int = (
90  # Task runs once per day, so roughly 3 months of retries.
)


settings = Settings()
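
Since Settings extends pydantic's BaseSettings, the new ceiling can presumably be overridden per deployment through an environment variable of the same name — at one scheduled run per day, 90 retries is roughly three months of attempts. A minimal sketch, assuming the default BaseSettings environment handling (not something this PR configures) and that the remaining required settings are available:

import os

# Assumed override path: pydantic BaseSettings maps fields to same-named env vars.
os.environ["CONTRACT_MAX_DOWNLOAD_RETRIES"] = "30"  # ~1 month at one run per day

from app.config import Settings

print(Settings().CONTRACT_MAX_DOWNLOAD_RETRIES)  # -> 30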
21 changes: 21 additions & 0 deletions app/datasources/db/models.py
@@ -279,3 +279,24 @@ async def get_abi_by_contract_address(
if result := results.first():
return cast(ABI, result)
return None

@classmethod
async def get_contracts_without_abi(
cls, session: AsyncSession, max_retries: int = 0
):
"""
Fetches contracts without an ABI and with at most max_retries fetch retries, streaming results in batches to reduce memory usage for large datasets.
More information about streaming results can be found here: https://docs.sqlalchemy.org/en/20/core/connections.html#streaming-with-a-dynamically-growing-buffer-using-stream-results

:param session:
:param max_retries:
:return:
"""
query = (
select(cls)
.where(cls.abi_id == None) # noqa: E711
.where(cls.fetch_retries <= max_retries)
)
result = await session.stream(query)
async for contract in result:
yield contract
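
As context for the callers added later in this PR (the test indexes each item with contract[0], and the worker reads contract[0].address): session.stream() yields Row objects rather than bare entities, so each row still has to be unpacked. A small consumption sketch, not taken from the PR:

from sqlmodel.ext.asyncio.session import AsyncSession

from app.datasources.db.models import Contract


async def collect_pending_contracts(
    session: AsyncSession, max_retries: int
) -> list[Contract]:
    # Hedged sketch: gather contracts still lacking an ABI, unpacking the Row objects
    # produced by session.stream(). For very large result sets you would normally keep
    # streaming instead of materialising a list.
    pending: list[Contract] = []
    async for row in Contract.get_contracts_without_abi(session, max_retries):
        pending.append(row[0])  # the Contract entity is the only selected column
    return pending

If the session wrapper exposes SQLAlchemy's stream_scalars(), that would yield the entities directly and remove the row[0] indexing; that is an assumption about the available API rather than something this PR uses.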
10 changes: 6 additions & 4 deletions app/services/contract_metadata_service.py
@@ -217,7 +217,7 @@ async def should_attempt_download(
session: AsyncSession,
contract_address: ChecksumAddress,
chain_id: int,
retries: int,
max_retries: int,
) -> bool:
"""
Return True if the contract's fetch retries do not exceed max_retries and it has no ABI; return False otherwise.
@@ -226,11 +226,13 @@
:param session:
:param contract_address:
:param chain_id:
:param retries:
:param max_retries:
:return:
"""
redis = get_redis()
cache_key = f"should_attempt_download:{contract_address}:{chain_id}:{retries}"
cache_key = (
f"should_attempt_download:{contract_address}:{chain_id}:{max_retries}"
)
# Try from cache first
cached_retries = cast(str, redis.get(cache_key))
if cached_retries:
Expand All @@ -240,7 +242,7 @@ async def should_attempt_download(
session, address=HexBytes(contract_address), chain_id=chain_id
)

if contract and (contract.fetch_retries > retries or contract.abi_id):
if contract and (contract.fetch_retries > max_retries or contract.abi_id):
redis.set(cache_key, 0)
return False

4 changes: 3 additions & 1 deletion app/services/events.py
@@ -20,7 +20,9 @@ def process_event(self, message: str) -> None:
if self._is_processable_event(tx_service_event):
chain_id = int(tx_service_event["chainId"])
contract_address = tx_service_event["to"]
get_contract_metadata_task.send(contract_address, chain_id)
get_contract_metadata_task.send(
address=contract_address, chain_id=chain_id
)
except json.JSONDecodeError:
logging.error(f"Unsupported message. Cannot parse as JSON: {message}")

29 changes: 29 additions & 0 deletions app/tests/datasources/db/test_model.py
@@ -1,3 +1,5 @@
from eth_account import Account
from hexbytes import HexBytes
from sqlmodel.ext.asyncio.session import AsyncSession

from app.datasources.db.database import database_session
@@ -134,3 +136,30 @@ async def test_timestamped_model(self, session: AsyncSession):
self.assertEqual(result_updated[0].created, contract_created_date)
self.assertNotEqual(result_updated[0].modified, contract_modified_date)
self.assertTrue(contract_modified_date < result_updated[0].modified)

@database_session
async def test_get_contracts_without_abi(self, session: AsyncSession):
random_address = HexBytes(Account.create().address)
abi_json = {"name": "A Test ABI"}
source = AbiSource(name="local", url="")
await source.create(session)
abi = Abi(abi_json=abi_json, source_id=source.id)
await abi.create(session)
# Should return the contract
expected_contract = await Contract(
address=random_address, name="A test contract", chain_id=1
).create(session)
async for contract in Contract.get_contracts_without_abi(session, 0):
self.assertEqual(expected_contract, contract[0])

# Contracts with more retries shouldn't be returned
expected_contract.fetch_retries = 1
await expected_contract.update(session)
async for contract in Contract.get_contracts_without_abi(session, 0):
self.fail("Expected no contracts, but found one.")

# Contracts with abi shouldn't be returned
expected_contract.abi_id = abi.id
await expected_contract.update(session)
async for contract in Contract.get_contracts_without_abi(session, 10):
self.fail("Expected no contracts, but found one.")
2 changes: 1 addition & 1 deletion app/tests/services/test_events.py
@@ -81,5 +81,5 @@ def test_process_event_calls_send(self, mock_get_contract_metadata_task):
EventsService().process_event(valid_message)

mock_get_contract_metadata_task.assert_called_once_with(
"0x6ED857dc1da2c41470A95589bB482152000773e9", 1
address="0x6ED857dc1da2c41470A95589bB482152000773e9", chain_id=1
)
48 changes: 44 additions & 4 deletions app/tests/workers/test_tasks.py
@@ -7,13 +7,16 @@
from dramatiq.worker import Worker
from eth_account import Account
from hexbytes import HexBytes
from safe_eth.eth import EthereumNetwork
from safe_eth.eth.clients import AsyncEtherscanClientV2
from sqlmodel.ext.asyncio.session import AsyncSession

from app.datasources.db.database import database_session
from app.datasources.db.models import Contract
from app.datasources.db.models import AbiSource, Contract
from app.workers.tasks import get_contract_metadata_task, redis_broker, test_task

from ...datasources.cache.redis import get_redis
from ...services.contract_metadata_service import ContractMetadataService
from ..datasources.db.db_async_conn import DbAsyncConn
from ..mocks.contract_metadata_mocks import (
etherscan_metadata_mock,
@@ -75,23 +78,60 @@ def _wait_tasks_execution(self):
while len(redis_tasks) > 0:
redis_tasks = self.worker.broker.client.lrange("dramatiq:default", 0, -1)

@mock.patch.object(ContractMetadataService, "enabled_clients")
@mock.patch.object(
AsyncEtherscanClientV2, "async_get_contract_metadata", autospec=True
)
@database_session
async def test_get_contract_metadata_task(
self, etherscan_get_contract_metadata_mock: MagicMock, session: AsyncSession
self,
etherscan_get_contract_metadata_mock: MagicMock,
mock_enabled_clients: MagicMock,
session: AsyncSession,
):
etherscan_get_contract_metadata_mock.return_value = etherscan_metadata_mock
contract_address = "0xd9Db270c1B5E3Bd161E8c8503c55cEABeE709552"
chain_id = 100
get_contract_metadata_task.fn(contract_address, chain_id)
cache_key = f"should_attempt_download:{contract_address}:{chain_id}:0"
redis = get_redis()
redis.delete(cache_key)
await AbiSource(name="Etherscan", url="").create(session)
etherscan_get_contract_metadata_mock.return_value = None
mock_enabled_clients.return_value = [
AsyncEtherscanClientV2(EthereumNetwork(chain_id))
]
# Should try one time
get_contract_metadata_task.fn(address=contract_address, chain_id=chain_id)
contract = await Contract.get_contract(
session, HexBytes(contract_address), chain_id
)
self.assertIsNotNone(contract)
self.assertIsNone(contract.abi_id)
self.assertEqual(etherscan_get_contract_metadata_mock.call_count, 1)

# Shouldn't try a second time
etherscan_get_contract_metadata_mock.return_value = etherscan_metadata_mock
chain_id = 100
get_contract_metadata_task.fn(address=contract_address, chain_id=chain_id)
contract = await Contract.get_contract(
session, HexBytes(contract_address), chain_id
)
self.assertIsNotNone(contract)
self.assertIsNone(contract.abi_id)
self.assertEqual(etherscan_get_contract_metadata_mock.call_count, 1)

# After resetting the cache and the database retry counter, it should download the contract
contract.fetch_retries = 0
redis.delete(cache_key)
await contract.update(session)
get_contract_metadata_task.fn(address=contract_address, chain_id=chain_id)
await session.refresh(contract)
contract = await Contract.get_contract(
session, HexBytes(contract_address), chain_id
)
self.assertIsNotNone(contract)
self.assertIsNotNone(contract.abi_id)
self.assertEqual(etherscan_get_contract_metadata_mock.call_count, 2)

@mock.patch.object(
AsyncEtherscanClientV2, "async_get_contract_metadata", autospec=True
)
35 changes: 28 additions & 7 deletions app/workers/tasks.py
@@ -3,16 +3,19 @@
import dramatiq
from dramatiq.brokers.redis import RedisBroker
from dramatiq.middleware import AsyncIO
from periodiq import PeriodiqMiddleware
from hexbytes import HexBytes
from periodiq import PeriodiqMiddleware, cron
from safe_eth.eth.utils import fast_to_checksum_address
from sqlmodel.ext.asyncio.session import AsyncSession

from ..config import settings
from ..datasources.db.database import database_session
from ..services.contract_metadata_service import get_contract_metadata_service
from app.config import settings
from app.datasources.db.database import database_session
from app.datasources.db.models import Contract
from app.services.contract_metadata_service import get_contract_metadata_service

logger = logging.getLogger(__name__)


redis_broker = RedisBroker(url=settings.REDIS_URL)
redis_broker.add_middleware(PeriodiqMiddleware(skip_delay=60))
redis_broker.add_middleware(AsyncIO())
@@ -37,11 +40,14 @@ async def test_task(message: str) -> None:
@dramatiq.actor
@database_session
async def get_contract_metadata_task(
address: str, chain_id: int, session: AsyncSession
session: AsyncSession,
address: str,
chain_id: int,
skip_attemp_download: bool = False,
) -> None:
contract_metadata_service = get_contract_metadata_service()
# Only try the first time; subsequent retries should be scheduled
if await contract_metadata_service.should_attempt_download(
if skip_attemp_download or await contract_metadata_service.should_attempt_download(

Contributor:
Nitpicky: the contract_metadata_service.should_attempt_download conditional can be moved into a separate variable; it would be more understandable in the future. The comment only applies to the second part of the condition.

Member (author):
It is because I only want to evaluate the second part of the conditional when the first one is not True.

session, address, chain_id, 0
):
logger.info(
Expand Down Expand Up @@ -77,6 +83,21 @@ async def get_contract_metadata_task(
address,
chain_id,
)
get_contract_metadata_task.send(proxy_implementation_address, chain_id)
get_contract_metadata_task.send(
address=proxy_implementation_address, chain_id=chain_id
)
else:
logger.debug("Skipping contract=%s and chain=%s", address, chain_id)


@dramatiq.actor(periodic=cron("0 0 * * *")) # Every midnight
@database_session
async def get_missing_contract_metadata_task(session: AsyncSession) -> None:
async for contract in Contract.get_contracts_without_abi(
session, settings.CONTRACT_MAX_DOWNLOAD_RETRIES
):
get_contract_metadata_task.send(
address=HexBytes(contract[0].address).hex(),
chain_id=contract[0].chain_id,
skip_attemp_download=True,
)