Skip to content

Commit

Permalink
501 seer add subgraph to subgraphhandler (#535)
Browse files Browse the repository at this point in the history
* Added pools call

* Fixed test

* Small fixes

* Added missing file

* Added method for retrieving seer binary markets more precisely

* Fixing filter

* Revert "Fixing filter"

This reverts commit 496c159.

* Fixing filter (2)

* Fixed outcome filter

* Fixed test

* Fixed CI
  • Loading branch information
gabrielfior authored Oct 31, 2024
1 parent 6626d8a commit 6862cbf
Show file tree
Hide file tree
Showing 5 changed files with 296 additions and 50 deletions.
51 changes: 51 additions & 0 deletions prediction_market_agent_tooling/markets/base_subgraph_handler.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
import typing as t

import tenacity
from pydantic import BaseModel
from subgrounds import FieldPath, Subgrounds

from prediction_market_agent_tooling.config import APIKeys
from prediction_market_agent_tooling.loggers import logger
from prediction_market_agent_tooling.tools.singleton import SingletonMeta

T = t.TypeVar("T", bound=BaseModel)


class BaseSubgraphHandler(metaclass=SingletonMeta):
def __init__(self) -> None:
self.sg = Subgrounds()
# Patch methods to retry on failure.
self.sg.query_json = tenacity.retry(
stop=tenacity.stop_after_attempt(3),
wait=tenacity.wait_fixed(1),
after=lambda x: logger.debug(f"query_json failed, {x.attempt_number=}."),
)(self.sg.query_json)
self.sg.load_subgraph = tenacity.retry(
stop=tenacity.stop_after_attempt(3),
wait=tenacity.wait_fixed(1),
after=lambda x: logger.debug(f"load_subgraph failed, {x.attempt_number=}."),
)(self.sg.load_subgraph)

self.keys = APIKeys()

def _parse_items_from_json(
self, result: list[dict[str, t.Any]]
) -> list[dict[str, t.Any]]:
"""subgrounds return a weird key as a dict key"""
items = []
for result_chunk in result:
for k, v in result_chunk.items():
# subgrounds might pack all items as a list, indexed by a key, or pack it as a dictionary (if one single element)
if v is None:
continue
elif isinstance(v, dict):
items.extend([v])
else:
items.extend(v)
return items

def do_query(self, fields: list[FieldPath], pydantic_model: t.Type[T]) -> list[T]:
result = self.sg.query_json(fields)
items = self._parse_items_from_json(result)
models = [pydantic_model.model_validate(i) for i in items]
return models
Original file line number Diff line number Diff line change
Expand Up @@ -2,21 +2,21 @@
import typing as t

import requests
import tenacity
from PIL import Image
from PIL.Image import Image as ImageType
from subgrounds import FieldPath, Subgrounds
from subgrounds import FieldPath

from prediction_market_agent_tooling.config import APIKeys
from prediction_market_agent_tooling.gtypes import (
ChecksumAddress,
HexAddress,
HexBytes,
Wei,
wei_type,
)
from prediction_market_agent_tooling.loggers import logger
from prediction_market_agent_tooling.markets.agent_market import FilterBy, SortBy
from prediction_market_agent_tooling.markets.base_subgraph_handler import (
BaseSubgraphHandler,
)
from prediction_market_agent_tooling.markets.omen.data_models import (
OMEN_BINARY_MARKET_OUTCOMES,
ContractPrediction,
Expand All @@ -33,7 +33,6 @@
WrappedxDaiContract,
sDaiContract,
)
from prediction_market_agent_tooling.tools.singleton import SingletonMeta
from prediction_market_agent_tooling.tools.utils import (
DatetimeUTC,
to_int_timestamp,
Expand All @@ -51,7 +50,7 @@
)


class OmenSubgraphHandler(metaclass=SingletonMeta):
class OmenSubgraphHandler(BaseSubgraphHandler):
"""
Class responsible for handling interactions with Omen subgraphs (trades, conditionalTokens).
"""
Expand All @@ -69,47 +68,33 @@ class OmenSubgraphHandler(metaclass=SingletonMeta):
INVALID_ANSWER = "ffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff"

def __init__(self) -> None:
self.sg = Subgrounds()

# Patch methods to retry on failure.
self.sg.query_json = tenacity.retry(
stop=tenacity.stop_after_attempt(3),
wait=tenacity.wait_fixed(1),
after=lambda x: logger.debug(f"query_json failed, {x.attempt_number=}."),
)(self.sg.query_json)
self.sg.load_subgraph = tenacity.retry(
stop=tenacity.stop_after_attempt(3),
wait=tenacity.wait_fixed(1),
after=lambda x: logger.debug(f"load_subgraph failed, {x.attempt_number=}."),
)(self.sg.load_subgraph)

keys = APIKeys()
super().__init__()

# Load the subgraph
self.trades_subgraph = self.sg.load_subgraph(
self.OMEN_TRADES_SUBGRAPH.format(
graph_api_key=keys.graph_api_key.get_secret_value()
graph_api_key=self.keys.graph_api_key.get_secret_value()
)
)
self.conditional_tokens_subgraph = self.sg.load_subgraph(
self.CONDITIONAL_TOKENS_SUBGRAPH.format(
graph_api_key=keys.graph_api_key.get_secret_value()
graph_api_key=self.keys.graph_api_key.get_secret_value()
)
)
self.realityeth_subgraph = self.sg.load_subgraph(
self.REALITYETH_GRAPH_URL.format(
graph_api_key=keys.graph_api_key.get_secret_value()
graph_api_key=self.keys.graph_api_key.get_secret_value()
)
)
self.omen_image_mapping_subgraph = self.sg.load_subgraph(
self.OMEN_IMAGE_MAPPING_GRAPH_URL.format(
graph_api_key=keys.graph_api_key.get_secret_value()
graph_api_key=self.keys.graph_api_key.get_secret_value()
)
)

self.omen_agent_result_mapping_subgraph = self.sg.load_subgraph(
self.OMEN_AGENT_RESULT_MAPPING_GRAPH_URL.format(
graph_api_key=keys.graph_api_key.get_secret_value()
graph_api_key=self.keys.graph_api_key.get_secret_value()
)
)

Expand Down Expand Up @@ -446,22 +431,17 @@ def get_omen_binary_markets(
**optional_params,
)

omen_markets = self.do_markets_query(markets)
return omen_markets

def do_markets_query(self, markets: FieldPath) -> list[OmenMarket]:
fields = self._get_fields_for_markets(markets)
result = self.sg.query_json(fields)
items = self._parse_items_from_json(result)
omen_markets = [OmenMarket.model_validate(i) for i in items]
omen_markets = self.do_query(fields=fields, pydantic_model=OmenMarket)
return omen_markets

def get_omen_market_by_market_id(self, market_id: HexAddress) -> OmenMarket:
markets = self.trades_subgraph.Query.fixedProductMarketMaker(
id=market_id.lower()
)

omen_markets = self.do_markets_query(markets)
fields = self._get_fields_for_markets(markets)
omen_markets = self.do_query(fields=fields, pydantic_model=OmenMarket)

if len(omen_markets) != 1:
raise ValueError(
Expand All @@ -470,22 +450,6 @@ def get_omen_market_by_market_id(self, market_id: HexAddress) -> OmenMarket:

return omen_markets[0]

def _parse_items_from_json(
self, result: list[dict[str, t.Any]]
) -> list[dict[str, t.Any]]:
"""subgrounds return a weird key as a dict key"""
items = []
for result_chunk in result:
for k, v in result_chunk.items():
# subgrounds might pack all items as a list, indexed by a key, or pack it as a dictionary (if one single element)
if v is None:
continue
elif isinstance(v, dict):
items.extend([v])
else:
items.extend(v)
return items

def _get_fields_for_user_positions(
self, user_positions: FieldPath
) -> list[FieldPath]:
Expand Down
27 changes: 27 additions & 0 deletions prediction_market_agent_tooling/markets/seer/data_models.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
from pydantic import BaseModel, ConfigDict, Field

from prediction_market_agent_tooling.gtypes import HexBytes


class SeerMarket(BaseModel):
model_config = ConfigDict(populate_by_name=True)

id: HexBytes
title: str = Field(alias="marketName")
outcomes: list[str]
parent_market: HexBytes = Field(alias="parentMarket")
wrapped_tokens: list[HexBytes] = Field(alias="wrappedTokens")


class SeerToken(BaseModel):
id: HexBytes
name: str
symbol: str


class SeerPool(BaseModel):
model_config = ConfigDict(populate_by_name=True)
id: HexBytes
liquidity: int
token0: SeerToken
token1: SeerToken
142 changes: 142 additions & 0 deletions prediction_market_agent_tooling/markets/seer/seer_subgraph_handler.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,142 @@
from typing import Any

from subgrounds import FieldPath
from web3.constants import ADDRESS_ZERO

from prediction_market_agent_tooling.markets.base_subgraph_handler import (
BaseSubgraphHandler,
)
from prediction_market_agent_tooling.markets.seer.data_models import (
SeerMarket,
SeerPool,
)
from prediction_market_agent_tooling.tools.hexbytes_custom import HexBytes

INVALID_OUTCOME = "Invalid result"


class SeerSubgraphHandler(BaseSubgraphHandler):
"""
Class responsible for handling interactions with Seer subgraphs.
"""

SEER_SUBGRAPH = "https://gateway-arbitrum.network.thegraph.com/api/{graph_api_key}/subgraphs/id/B4vyRqJaSHD8dRDb3BFRoAzuBK18c1QQcXq94JbxDxWH"

SWAPR_ALGEBRA_SUBGRAPH = "https://gateway-arbitrum.network.thegraph.com/api/{graph_api_key}/subgraphs/id/AAA1vYjxwFHzbt6qKwLHNcDSASyr1J1xVViDH8gTMFMR"

INVALID_ANSWER = "ffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff"

def __init__(self) -> None:
super().__init__()

self.seer_subgraph = self.sg.load_subgraph(
self.SEER_SUBGRAPH.format(
graph_api_key=self.keys.graph_api_key.get_secret_value()
)
)
self.swapr_algebra_subgraph = self.sg.load_subgraph(
self.SWAPR_ALGEBRA_SUBGRAPH.format(
graph_api_key=self.keys.graph_api_key.get_secret_value()
)
)

def _get_fields_for_markets(self, markets_field: FieldPath) -> list[FieldPath]:
fields = [
markets_field.id,
markets_field.factory,
markets_field.creator,
markets_field.marketName,
markets_field.outcomes,
markets_field.parentMarket,
markets_field.finalizeTs,
markets_field.wrappedTokens,
]
return fields

@staticmethod
def filter_bicategorical_markets(markets: list[SeerMarket]) -> list[SeerMarket]:
# We do an extra check for the invalid outcome for safety.
return [
m for m in markets if len(m.outcomes) == 3 and INVALID_OUTCOME in m.outcomes
]

@staticmethod
def filter_binary_markets(markets: list[SeerMarket]) -> list[SeerMarket]:
return [
market
for market in markets
if {"yes", "no"}.issubset({o.lower() for o in market.outcomes})
]

@staticmethod
def build_filter_for_conditional_markets(
include_conditional_markets: bool = True,
) -> dict[Any, Any]:
return (
{}
if include_conditional_markets
else {"parentMarket": ADDRESS_ZERO.lower()}
)

def get_bicategorical_markets(
self, include_conditional_markets: bool = True
) -> list[SeerMarket]:
"""Returns markets that contain 2 categories plus an invalid outcome."""
# Binary markets on Seer contain 3 outcomes: OutcomeA, outcomeB and an Invalid option.
query_filter = self.build_filter_for_conditional_markets(
include_conditional_markets
)
query_filter["outcomes_contains"] = [INVALID_OUTCOME]
markets_field = self.seer_subgraph.Query.markets(where=query_filter)
fields = self._get_fields_for_markets(markets_field)
markets = self.do_query(fields=fields, pydantic_model=SeerMarket)
two_category_markets = self.filter_bicategorical_markets(markets)
return two_category_markets

def get_binary_markets(
self, include_conditional_markets: bool = True
) -> list[SeerMarket]:
two_category_markets = self.get_bicategorical_markets(
include_conditional_markets=include_conditional_markets
)
# Now we additionally filter markets based on YES/NO being the only outcomes.
binary_markets = self.filter_binary_markets(two_category_markets)
return binary_markets

def get_market_by_id(self, market_id: HexBytes) -> SeerMarket:
markets_field = self.seer_subgraph.Query.market(id=market_id.hex().lower())
fields = self._get_fields_for_markets(markets_field)
markets = self.do_query(fields=fields, pydantic_model=SeerMarket)
if len(markets) != 1:
raise ValueError(
f"Fetched wrong number of markets. Expected 1 but got {len(markets)}"
)
return markets[0]

def _get_fields_for_pools(self, pools_field: FieldPath) -> list[FieldPath]:
fields = [
pools_field.id,
pools_field.liquidity,
pools_field.token0.id,
pools_field.token0.name,
pools_field.token0.symbol,
pools_field.token1.id,
pools_field.token1.name,
pools_field.token1.symbol,
]
return fields

def get_pools_for_market(self, market: SeerMarket) -> list[SeerPool]:
# We iterate through the wrapped tokens and put them in a where clause so that we hit the subgraph endpoint just once.
wheres = []
for wrapped_token in market.wrapped_tokens:
wheres.extend(
[
{"token0": wrapped_token.hex().lower()},
{"token1": wrapped_token.hex().lower()},
]
)
pools_field = self.swapr_algebra_subgraph.Query.pools(where={"or": wheres})
fields = self._get_fields_for_pools(pools_field)
pools = self.do_query(fields=fields, pydantic_model=SeerPool)
return pools
Loading

0 comments on commit 6862cbf

Please sign in to comment.