diff --git a/prediction_market_agent_tooling/markets/base_subgraph_handler.py b/prediction_market_agent_tooling/markets/base_subgraph_handler.py new file mode 100644 index 00000000..0fc93a6c --- /dev/null +++ b/prediction_market_agent_tooling/markets/base_subgraph_handler.py @@ -0,0 +1,51 @@ +import typing as t + +import tenacity +from pydantic import BaseModel +from subgrounds import FieldPath, Subgrounds + +from prediction_market_agent_tooling.config import APIKeys +from prediction_market_agent_tooling.loggers import logger +from prediction_market_agent_tooling.tools.singleton import SingletonMeta + +T = t.TypeVar("T", bound=BaseModel) + + +class BaseSubgraphHandler(metaclass=SingletonMeta): + def __init__(self) -> None: + self.sg = Subgrounds() + # Patch methods to retry on failure. + self.sg.query_json = tenacity.retry( + stop=tenacity.stop_after_attempt(3), + wait=tenacity.wait_fixed(1), + after=lambda x: logger.debug(f"query_json failed, {x.attempt_number=}."), + )(self.sg.query_json) + self.sg.load_subgraph = tenacity.retry( + stop=tenacity.stop_after_attempt(3), + wait=tenacity.wait_fixed(1), + after=lambda x: logger.debug(f"load_subgraph failed, {x.attempt_number=}."), + )(self.sg.load_subgraph) + + self.keys = APIKeys() + + def _parse_items_from_json( + self, result: list[dict[str, t.Any]] + ) -> list[dict[str, t.Any]]: + """subgrounds return a weird key as a dict key""" + items = [] + for result_chunk in result: + for k, v in result_chunk.items(): + # subgrounds might pack all items as a list, indexed by a key, or pack it as a dictionary (if one single element) + if v is None: + continue + elif isinstance(v, dict): + items.extend([v]) + else: + items.extend(v) + return items + + def do_query(self, fields: list[FieldPath], pydantic_model: t.Type[T]) -> list[T]: + result = self.sg.query_json(fields) + items = self._parse_items_from_json(result) + models = [pydantic_model.model_validate(i) for i in items] + return models diff --git a/prediction_market_agent_tooling/markets/omen/omen_subgraph_handler.py b/prediction_market_agent_tooling/markets/omen/omen_subgraph_handler.py index e1513c5f..ebd662ed 100644 --- a/prediction_market_agent_tooling/markets/omen/omen_subgraph_handler.py +++ b/prediction_market_agent_tooling/markets/omen/omen_subgraph_handler.py @@ -2,12 +2,10 @@ import typing as t import requests -import tenacity from PIL import Image from PIL.Image import Image as ImageType -from subgrounds import FieldPath, Subgrounds +from subgrounds import FieldPath -from prediction_market_agent_tooling.config import APIKeys from prediction_market_agent_tooling.gtypes import ( ChecksumAddress, HexAddress, @@ -15,8 +13,10 @@ Wei, wei_type, ) -from prediction_market_agent_tooling.loggers import logger from prediction_market_agent_tooling.markets.agent_market import FilterBy, SortBy +from prediction_market_agent_tooling.markets.base_subgraph_handler import ( + BaseSubgraphHandler, +) from prediction_market_agent_tooling.markets.omen.data_models import ( OMEN_BINARY_MARKET_OUTCOMES, ContractPrediction, @@ -33,7 +33,6 @@ WrappedxDaiContract, sDaiContract, ) -from prediction_market_agent_tooling.tools.singleton import SingletonMeta from prediction_market_agent_tooling.tools.utils import ( DatetimeUTC, to_int_timestamp, @@ -51,7 +50,7 @@ ) -class OmenSubgraphHandler(metaclass=SingletonMeta): +class OmenSubgraphHandler(BaseSubgraphHandler): """ Class responsible for handling interactions with Omen subgraphs (trades, conditionalTokens). """ @@ -69,47 +68,33 @@ class OmenSubgraphHandler(metaclass=SingletonMeta): INVALID_ANSWER = "ffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff" def __init__(self) -> None: - self.sg = Subgrounds() - - # Patch methods to retry on failure. - self.sg.query_json = tenacity.retry( - stop=tenacity.stop_after_attempt(3), - wait=tenacity.wait_fixed(1), - after=lambda x: logger.debug(f"query_json failed, {x.attempt_number=}."), - )(self.sg.query_json) - self.sg.load_subgraph = tenacity.retry( - stop=tenacity.stop_after_attempt(3), - wait=tenacity.wait_fixed(1), - after=lambda x: logger.debug(f"load_subgraph failed, {x.attempt_number=}."), - )(self.sg.load_subgraph) - - keys = APIKeys() + super().__init__() # Load the subgraph self.trades_subgraph = self.sg.load_subgraph( self.OMEN_TRADES_SUBGRAPH.format( - graph_api_key=keys.graph_api_key.get_secret_value() + graph_api_key=self.keys.graph_api_key.get_secret_value() ) ) self.conditional_tokens_subgraph = self.sg.load_subgraph( self.CONDITIONAL_TOKENS_SUBGRAPH.format( - graph_api_key=keys.graph_api_key.get_secret_value() + graph_api_key=self.keys.graph_api_key.get_secret_value() ) ) self.realityeth_subgraph = self.sg.load_subgraph( self.REALITYETH_GRAPH_URL.format( - graph_api_key=keys.graph_api_key.get_secret_value() + graph_api_key=self.keys.graph_api_key.get_secret_value() ) ) self.omen_image_mapping_subgraph = self.sg.load_subgraph( self.OMEN_IMAGE_MAPPING_GRAPH_URL.format( - graph_api_key=keys.graph_api_key.get_secret_value() + graph_api_key=self.keys.graph_api_key.get_secret_value() ) ) self.omen_agent_result_mapping_subgraph = self.sg.load_subgraph( self.OMEN_AGENT_RESULT_MAPPING_GRAPH_URL.format( - graph_api_key=keys.graph_api_key.get_secret_value() + graph_api_key=self.keys.graph_api_key.get_secret_value() ) ) @@ -446,14 +431,8 @@ def get_omen_binary_markets( **optional_params, ) - omen_markets = self.do_markets_query(markets) - return omen_markets - - def do_markets_query(self, markets: FieldPath) -> list[OmenMarket]: fields = self._get_fields_for_markets(markets) - result = self.sg.query_json(fields) - items = self._parse_items_from_json(result) - omen_markets = [OmenMarket.model_validate(i) for i in items] + omen_markets = self.do_query(fields=fields, pydantic_model=OmenMarket) return omen_markets def get_omen_market_by_market_id(self, market_id: HexAddress) -> OmenMarket: @@ -461,7 +440,8 @@ def get_omen_market_by_market_id(self, market_id: HexAddress) -> OmenMarket: id=market_id.lower() ) - omen_markets = self.do_markets_query(markets) + fields = self._get_fields_for_markets(markets) + omen_markets = self.do_query(fields=fields, pydantic_model=OmenMarket) if len(omen_markets) != 1: raise ValueError( @@ -470,22 +450,6 @@ def get_omen_market_by_market_id(self, market_id: HexAddress) -> OmenMarket: return omen_markets[0] - def _parse_items_from_json( - self, result: list[dict[str, t.Any]] - ) -> list[dict[str, t.Any]]: - """subgrounds return a weird key as a dict key""" - items = [] - for result_chunk in result: - for k, v in result_chunk.items(): - # subgrounds might pack all items as a list, indexed by a key, or pack it as a dictionary (if one single element) - if v is None: - continue - elif isinstance(v, dict): - items.extend([v]) - else: - items.extend(v) - return items - def _get_fields_for_user_positions( self, user_positions: FieldPath ) -> list[FieldPath]: diff --git a/prediction_market_agent_tooling/markets/seer/data_models.py b/prediction_market_agent_tooling/markets/seer/data_models.py new file mode 100644 index 00000000..a0a45838 --- /dev/null +++ b/prediction_market_agent_tooling/markets/seer/data_models.py @@ -0,0 +1,27 @@ +from pydantic import BaseModel, ConfigDict, Field + +from prediction_market_agent_tooling.gtypes import HexBytes + + +class SeerMarket(BaseModel): + model_config = ConfigDict(populate_by_name=True) + + id: HexBytes + title: str = Field(alias="marketName") + outcomes: list[str] + parent_market: HexBytes = Field(alias="parentMarket") + wrapped_tokens: list[HexBytes] = Field(alias="wrappedTokens") + + +class SeerToken(BaseModel): + id: HexBytes + name: str + symbol: str + + +class SeerPool(BaseModel): + model_config = ConfigDict(populate_by_name=True) + id: HexBytes + liquidity: int + token0: SeerToken + token1: SeerToken diff --git a/prediction_market_agent_tooling/markets/seer/seer_subgraph_handler.py b/prediction_market_agent_tooling/markets/seer/seer_subgraph_handler.py new file mode 100644 index 00000000..69845fbe --- /dev/null +++ b/prediction_market_agent_tooling/markets/seer/seer_subgraph_handler.py @@ -0,0 +1,142 @@ +from typing import Any + +from subgrounds import FieldPath +from web3.constants import ADDRESS_ZERO + +from prediction_market_agent_tooling.markets.base_subgraph_handler import ( + BaseSubgraphHandler, +) +from prediction_market_agent_tooling.markets.seer.data_models import ( + SeerMarket, + SeerPool, +) +from prediction_market_agent_tooling.tools.hexbytes_custom import HexBytes + +INVALID_OUTCOME = "Invalid result" + + +class SeerSubgraphHandler(BaseSubgraphHandler): + """ + Class responsible for handling interactions with Seer subgraphs. + """ + + SEER_SUBGRAPH = "https://gateway-arbitrum.network.thegraph.com/api/{graph_api_key}/subgraphs/id/B4vyRqJaSHD8dRDb3BFRoAzuBK18c1QQcXq94JbxDxWH" + + SWAPR_ALGEBRA_SUBGRAPH = "https://gateway-arbitrum.network.thegraph.com/api/{graph_api_key}/subgraphs/id/AAA1vYjxwFHzbt6qKwLHNcDSASyr1J1xVViDH8gTMFMR" + + INVALID_ANSWER = "ffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff" + + def __init__(self) -> None: + super().__init__() + + self.seer_subgraph = self.sg.load_subgraph( + self.SEER_SUBGRAPH.format( + graph_api_key=self.keys.graph_api_key.get_secret_value() + ) + ) + self.swapr_algebra_subgraph = self.sg.load_subgraph( + self.SWAPR_ALGEBRA_SUBGRAPH.format( + graph_api_key=self.keys.graph_api_key.get_secret_value() + ) + ) + + def _get_fields_for_markets(self, markets_field: FieldPath) -> list[FieldPath]: + fields = [ + markets_field.id, + markets_field.factory, + markets_field.creator, + markets_field.marketName, + markets_field.outcomes, + markets_field.parentMarket, + markets_field.finalizeTs, + markets_field.wrappedTokens, + ] + return fields + + @staticmethod + def filter_bicategorical_markets(markets: list[SeerMarket]) -> list[SeerMarket]: + # We do an extra check for the invalid outcome for safety. + return [ + m for m in markets if len(m.outcomes) == 3 and INVALID_OUTCOME in m.outcomes + ] + + @staticmethod + def filter_binary_markets(markets: list[SeerMarket]) -> list[SeerMarket]: + return [ + market + for market in markets + if {"yes", "no"}.issubset({o.lower() for o in market.outcomes}) + ] + + @staticmethod + def build_filter_for_conditional_markets( + include_conditional_markets: bool = True, + ) -> dict[Any, Any]: + return ( + {} + if include_conditional_markets + else {"parentMarket": ADDRESS_ZERO.lower()} + ) + + def get_bicategorical_markets( + self, include_conditional_markets: bool = True + ) -> list[SeerMarket]: + """Returns markets that contain 2 categories plus an invalid outcome.""" + # Binary markets on Seer contain 3 outcomes: OutcomeA, outcomeB and an Invalid option. + query_filter = self.build_filter_for_conditional_markets( + include_conditional_markets + ) + query_filter["outcomes_contains"] = [INVALID_OUTCOME] + markets_field = self.seer_subgraph.Query.markets(where=query_filter) + fields = self._get_fields_for_markets(markets_field) + markets = self.do_query(fields=fields, pydantic_model=SeerMarket) + two_category_markets = self.filter_bicategorical_markets(markets) + return two_category_markets + + def get_binary_markets( + self, include_conditional_markets: bool = True + ) -> list[SeerMarket]: + two_category_markets = self.get_bicategorical_markets( + include_conditional_markets=include_conditional_markets + ) + # Now we additionally filter markets based on YES/NO being the only outcomes. + binary_markets = self.filter_binary_markets(two_category_markets) + return binary_markets + + def get_market_by_id(self, market_id: HexBytes) -> SeerMarket: + markets_field = self.seer_subgraph.Query.market(id=market_id.hex().lower()) + fields = self._get_fields_for_markets(markets_field) + markets = self.do_query(fields=fields, pydantic_model=SeerMarket) + if len(markets) != 1: + raise ValueError( + f"Fetched wrong number of markets. Expected 1 but got {len(markets)}" + ) + return markets[0] + + def _get_fields_for_pools(self, pools_field: FieldPath) -> list[FieldPath]: + fields = [ + pools_field.id, + pools_field.liquidity, + pools_field.token0.id, + pools_field.token0.name, + pools_field.token0.symbol, + pools_field.token1.id, + pools_field.token1.name, + pools_field.token1.symbol, + ] + return fields + + def get_pools_for_market(self, market: SeerMarket) -> list[SeerPool]: + # We iterate through the wrapped tokens and put them in a where clause so that we hit the subgraph endpoint just once. + wheres = [] + for wrapped_token in market.wrapped_tokens: + wheres.extend( + [ + {"token0": wrapped_token.hex().lower()}, + {"token1": wrapped_token.hex().lower()}, + ] + ) + pools_field = self.swapr_algebra_subgraph.Query.pools(where={"or": wheres}) + fields = self._get_fields_for_pools(pools_field) + pools = self.do_query(fields=fields, pydantic_model=SeerPool) + return pools diff --git a/tests_integration/markets/seer/test_seer_subgraph_handler.py b/tests_integration/markets/seer/test_seer_subgraph_handler.py new file mode 100644 index 00000000..a4c30fb2 --- /dev/null +++ b/tests_integration/markets/seer/test_seer_subgraph_handler.py @@ -0,0 +1,62 @@ +import typing as t + +import pytest + +from prediction_market_agent_tooling.markets.seer.seer_subgraph_handler import ( + SeerSubgraphHandler, +) +from prediction_market_agent_tooling.tools.hexbytes_custom import HexBytes + +CONDITIONAL_MARKET_ID = HexBytes("0xe12f48ecdd6e64d95d1d8f1d5d7aa37e14f2888b") +BINARY_MARKET_ID = HexBytes("0x7d72aa56ecdda207005fd7a02dbfd33f92d0def7") +BINARY_CONDITIONAL_MARKET_ID = HexBytes("0xbc82402814f7db8736980c0debb01df6aad8846e") + + +@pytest.fixture(scope="module") +def handler() -> t.Generator[SeerSubgraphHandler, None, None]: + yield SeerSubgraphHandler() + + +def test_get_all_seer_markets(handler: SeerSubgraphHandler) -> None: + markets = handler.get_bicategorical_markets() + assert len(markets) > 1 + + +def test_get_seer_market_by_id(handler: SeerSubgraphHandler) -> None: + market_id = HexBytes("0x03cbd8e3a45c727643b015318fff883e13937fdd") + market = handler.get_market_by_id(market_id) + assert market is not None + assert market.id == market_id + + +def test_conditional_market_not_retrieved(handler: SeerSubgraphHandler) -> None: + markets = handler.get_bicategorical_markets(include_conditional_markets=False) + market_ids = [m.id for m in markets] + assert CONDITIONAL_MARKET_ID not in market_ids + + +def test_conditional_market_retrieved(handler: SeerSubgraphHandler) -> None: + markets = handler.get_bicategorical_markets(include_conditional_markets=True) + market_ids = [m.id for m in markets] + assert CONDITIONAL_MARKET_ID in market_ids + + +def test_binary_market_retrieved(handler: SeerSubgraphHandler) -> None: + markets = handler.get_binary_markets(include_conditional_markets=True) + market_ids = [m.id for m in markets] + assert BINARY_MARKET_ID in market_ids + assert BINARY_CONDITIONAL_MARKET_ID in market_ids + + +def test_get_pools_for_market(handler: SeerSubgraphHandler) -> None: + us_election_market_id = HexBytes("0x43d881f5920ed29fc5cd4917d6817496abbba6d9") + market = handler.get_market_by_id(us_election_market_id) + + pools = handler.get_pools_for_market(market) + assert len(pools) > 1 + for pool in pools: + # one of the tokens must be a wrapped token + assert ( + pool.token0.id in market.wrapped_tokens + or pool.token1.id in market.wrapped_tokens + )