Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

501 seer add subgraph to subgraphhandler #535

Merged
merged 12 commits into from
Oct 31, 2024
51 changes: 51 additions & 0 deletions prediction_market_agent_tooling/markets/base_subgraph_handler.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
import typing as t

import tenacity
from pydantic import BaseModel
from subgrounds import FieldPath, Subgrounds

from prediction_market_agent_tooling.config import APIKeys
from prediction_market_agent_tooling.loggers import logger
from prediction_market_agent_tooling.tools.singleton import SingletonMeta

T = t.TypeVar("T", bound=BaseModel)


class BaseSubgraphHandler(metaclass=SingletonMeta):
def __init__(self) -> None:
self.sg = Subgrounds()
# Patch methods to retry on failure.
self.sg.query_json = tenacity.retry(
stop=tenacity.stop_after_attempt(3),
wait=tenacity.wait_fixed(1),
after=lambda x: logger.debug(f"query_json failed, {x.attempt_number=}."),
)(self.sg.query_json)
self.sg.load_subgraph = tenacity.retry(
stop=tenacity.stop_after_attempt(3),
wait=tenacity.wait_fixed(1),
after=lambda x: logger.debug(f"load_subgraph failed, {x.attempt_number=}."),
)(self.sg.load_subgraph)

self.keys = APIKeys()

def _parse_items_from_json(
self, result: list[dict[str, t.Any]]
) -> list[dict[str, t.Any]]:
"""subgrounds return a weird key as a dict key"""
items = []
for result_chunk in result:
for k, v in result_chunk.items():
Copy link

@coderabbitai coderabbitai bot Oct 30, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue

Rename unused variable k to _k

The variable k in the loop is not used within the loop body. Renaming it to _k indicates that it is intentionally unused and follows Python conventions.

Apply this diff to fix the issue:

-for k, v in result_chunk.items():
+for _k, v in result_chunk.items():
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
for k, v in result_chunk.items():
for _k, v in result_chunk.items():
🧰 Tools
🪛 Ruff

40-40: Loop control variable k not used within loop body

Rename unused k to _k

(B007)

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Common that's too big neat picking! 😄 And if anything one could just use result_chunk.values() to not have the extra variable at all!

Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@kongzii, you're absolutely right! Using result_chunk.values() would simplify the code and eliminate the need for the unused variable altogether. Thank you for pointing that out!

🐰🔍


✏️ Learnings added
Learnt from: kongzii
PR: gnosis/prediction-market-agent-tooling#535
File: prediction_market_agent_tooling/markets/base_subgraph_handler.py:37-37
Timestamp: 2024-10-31T16:29:04.548Z
Learning: In `BaseSubgraphHandler._parse_items_from_json`, when iterating over a dictionary and only the values are needed, prefer using `dict.values()` instead of `dict.items()` to avoid unused variables.

Note: Learnings are effective only in the context of similar code segments. To apply general coding guidelines, please configure review instructions. You can manage existing learnings in the app.

# subgrounds might pack all items as a list, indexed by a key, or pack it as a dictionary (if one single element)
if v is None:
continue
elif isinstance(v, dict):
items.extend([v])
else:
items.extend(v)
return items

def do_query(self, fields: list[FieldPath], pydantic_model: t.Type[T]) -> list[T]:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nice!

result = self.sg.query_json(fields)
items = self._parse_items_from_json(result)
models = [pydantic_model.model_validate(i) for i in items]
return models
Original file line number Diff line number Diff line change
Expand Up @@ -2,21 +2,21 @@
import typing as t

import requests
import tenacity
from PIL import Image
from PIL.Image import Image as ImageType
from subgrounds import FieldPath, Subgrounds
from subgrounds import FieldPath

from prediction_market_agent_tooling.config import APIKeys
from prediction_market_agent_tooling.gtypes import (
ChecksumAddress,
HexAddress,
HexBytes,
Wei,
wei_type,
)
from prediction_market_agent_tooling.loggers import logger
from prediction_market_agent_tooling.markets.agent_market import FilterBy, SortBy
from prediction_market_agent_tooling.markets.base_subgraph_handler import (
BaseSubgraphHandler,
)
from prediction_market_agent_tooling.markets.omen.data_models import (
OMEN_BINARY_MARKET_OUTCOMES,
ContractPrediction,
Expand All @@ -33,7 +33,6 @@
WrappedxDaiContract,
sDaiContract,
)
from prediction_market_agent_tooling.tools.singleton import SingletonMeta
from prediction_market_agent_tooling.tools.utils import (
DatetimeUTC,
to_int_timestamp,
Expand All @@ -51,7 +50,7 @@
)


class OmenSubgraphHandler(metaclass=SingletonMeta):
class OmenSubgraphHandler(BaseSubgraphHandler):
"""
Class responsible for handling interactions with Omen subgraphs (trades, conditionalTokens).
"""
Expand All @@ -69,47 +68,33 @@ class OmenSubgraphHandler(metaclass=SingletonMeta):
INVALID_ANSWER = "ffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff"

def __init__(self) -> None:
self.sg = Subgrounds()

# Patch methods to retry on failure.
self.sg.query_json = tenacity.retry(
stop=tenacity.stop_after_attempt(3),
wait=tenacity.wait_fixed(1),
after=lambda x: logger.debug(f"query_json failed, {x.attempt_number=}."),
)(self.sg.query_json)
self.sg.load_subgraph = tenacity.retry(
stop=tenacity.stop_after_attempt(3),
wait=tenacity.wait_fixed(1),
after=lambda x: logger.debug(f"load_subgraph failed, {x.attempt_number=}."),
)(self.sg.load_subgraph)

keys = APIKeys()
super().__init__()

# Load the subgraph
self.trades_subgraph = self.sg.load_subgraph(
self.OMEN_TRADES_SUBGRAPH.format(
graph_api_key=keys.graph_api_key.get_secret_value()
graph_api_key=self.keys.graph_api_key.get_secret_value()
)
)
self.conditional_tokens_subgraph = self.sg.load_subgraph(
self.CONDITIONAL_TOKENS_SUBGRAPH.format(
graph_api_key=keys.graph_api_key.get_secret_value()
graph_api_key=self.keys.graph_api_key.get_secret_value()
)
)
self.realityeth_subgraph = self.sg.load_subgraph(
self.REALITYETH_GRAPH_URL.format(
graph_api_key=keys.graph_api_key.get_secret_value()
graph_api_key=self.keys.graph_api_key.get_secret_value()
)
)
self.omen_image_mapping_subgraph = self.sg.load_subgraph(
self.OMEN_IMAGE_MAPPING_GRAPH_URL.format(
graph_api_key=keys.graph_api_key.get_secret_value()
graph_api_key=self.keys.graph_api_key.get_secret_value()
)
)

self.omen_agent_result_mapping_subgraph = self.sg.load_subgraph(
self.OMEN_AGENT_RESULT_MAPPING_GRAPH_URL.format(
graph_api_key=keys.graph_api_key.get_secret_value()
graph_api_key=self.keys.graph_api_key.get_secret_value()
)
)

Expand Down Expand Up @@ -446,22 +431,17 @@ def get_omen_binary_markets(
**optional_params,
)

omen_markets = self.do_markets_query(markets)
return omen_markets

def do_markets_query(self, markets: FieldPath) -> list[OmenMarket]:
fields = self._get_fields_for_markets(markets)
result = self.sg.query_json(fields)
items = self._parse_items_from_json(result)
omen_markets = [OmenMarket.model_validate(i) for i in items]
omen_markets = self.do_query(fields=fields, pydantic_model=OmenMarket)
return omen_markets

def get_omen_market_by_market_id(self, market_id: HexAddress) -> OmenMarket:
markets = self.trades_subgraph.Query.fixedProductMarketMaker(
id=market_id.lower()
)

omen_markets = self.do_markets_query(markets)
fields = self._get_fields_for_markets(markets)
omen_markets = self.do_query(fields=fields, pydantic_model=OmenMarket)

if len(omen_markets) != 1:
raise ValueError(
Expand All @@ -470,22 +450,6 @@ def get_omen_market_by_market_id(self, market_id: HexAddress) -> OmenMarket:

return omen_markets[0]

def _parse_items_from_json(
self, result: list[dict[str, t.Any]]
) -> list[dict[str, t.Any]]:
"""subgrounds return a weird key as a dict key"""
items = []
for result_chunk in result:
for k, v in result_chunk.items():
# subgrounds might pack all items as a list, indexed by a key, or pack it as a dictionary (if one single element)
if v is None:
continue
elif isinstance(v, dict):
items.extend([v])
else:
items.extend(v)
return items

def _get_fields_for_user_positions(
self, user_positions: FieldPath
) -> list[FieldPath]:
Expand Down
27 changes: 27 additions & 0 deletions prediction_market_agent_tooling/markets/seer/data_models.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
from pydantic import BaseModel, ConfigDict, Field

from prediction_market_agent_tooling.gtypes import HexBytes


class SeerMarket(BaseModel):
model_config = ConfigDict(populate_by_name=True)

id: HexBytes
title: str = Field(alias="marketName")
outcomes: list[str]
parent_market: HexBytes = Field(alias="parentMarket")
wrapped_tokens: list[HexBytes] = Field(alias="wrappedTokens")


class SeerToken(BaseModel):
id: HexBytes
name: str
symbol: str


class SeerPool(BaseModel):
model_config = ConfigDict(populate_by_name=True)
id: HexBytes
liquidity: int
token0: SeerToken
token1: SeerToken
142 changes: 142 additions & 0 deletions prediction_market_agent_tooling/markets/seer/seer_subgraph_handler.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,142 @@
from typing import Any

from subgrounds import FieldPath
from web3.constants import ADDRESS_ZERO

from prediction_market_agent_tooling.markets.base_subgraph_handler import (
BaseSubgraphHandler,
)
from prediction_market_agent_tooling.markets.seer.data_models import (
SeerMarket,
SeerPool,
)
from prediction_market_agent_tooling.tools.hexbytes_custom import HexBytes

INVALID_OUTCOME = "Invalid result"


class SeerSubgraphHandler(BaseSubgraphHandler):
"""
Class responsible for handling interactions with Seer subgraphs.
"""

SEER_SUBGRAPH = "https://gateway-arbitrum.network.thegraph.com/api/{graph_api_key}/subgraphs/id/B4vyRqJaSHD8dRDb3BFRoAzuBK18c1QQcXq94JbxDxWH"

SWAPR_ALGEBRA_SUBGRAPH = "https://gateway-arbitrum.network.thegraph.com/api/{graph_api_key}/subgraphs/id/AAA1vYjxwFHzbt6qKwLHNcDSASyr1J1xVViDH8gTMFMR"

INVALID_ANSWER = "ffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff"
gabrielfior marked this conversation as resolved.
Show resolved Hide resolved

def __init__(self) -> None:
super().__init__()

self.seer_subgraph = self.sg.load_subgraph(
self.SEER_SUBGRAPH.format(
graph_api_key=self.keys.graph_api_key.get_secret_value()
)
)
self.swapr_algebra_subgraph = self.sg.load_subgraph(
self.SWAPR_ALGEBRA_SUBGRAPH.format(
graph_api_key=self.keys.graph_api_key.get_secret_value()
)
)
Comment on lines +29 to +41
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue

Add error handling for subgraph loading failures

Currently, if the subgraphs fail to load due to an invalid graph_api_key or network issues, the constructor will raise an exception, possibly causing the application to crash. Consider adding try-except blocks to handle exceptions during subgraph loading and provide informative error messages or fallback mechanisms.

Apply this change to add error handling:

def __init__(self) -> None:
    super().__init__()

    try:
        self.seer_subgraph = self.sg.load_subgraph(
            self.SEER_SUBGRAPH.format(
                graph_api_key=self.keys.graph_api_key.get_secret_value()
            )
        )
        self.swapr_algebra_subgraph = self.sg.load_subgraph(
            self.SWAPR_ALGEBRA_SUBGRAPH.format(
                graph_api_key=self.keys.graph_api_key.get_secret_value()
            )
        )
    except Exception as e:
        # Handle the exception or log an error message
        raise ConnectionError(f"Failed to load subgraphs: {e}")


def _get_fields_for_markets(self, markets_field: FieldPath) -> list[FieldPath]:
fields = [
markets_field.id,
markets_field.factory,
markets_field.creator,
markets_field.marketName,
markets_field.outcomes,
markets_field.parentMarket,
markets_field.finalizeTs,
markets_field.wrappedTokens,
]
return fields

@staticmethod
def filter_bicategorical_markets(markets: list[SeerMarket]) -> list[SeerMarket]:
# We do an extra check for the invalid outcome for safety.
return [
m for m in markets if len(m.outcomes) == 3 and INVALID_OUTCOME in m.outcomes
]

@staticmethod
def filter_binary_markets(markets: list[SeerMarket]) -> list[SeerMarket]:
return [
market
for market in markets
if {"yes", "no"}.issubset({o.lower() for o in market.outcomes})
]

@staticmethod
def build_filter_for_conditional_markets(
include_conditional_markets: bool = True,
) -> dict[Any, Any]:
return (
{}
if include_conditional_markets
else {"parentMarket": ADDRESS_ZERO.lower()}
)

def get_bicategorical_markets(
self, include_conditional_markets: bool = True
) -> list[SeerMarket]:
"""Returns markets that contain 2 categories plus an invalid outcome."""
# Binary markets on Seer contain 3 outcomes: OutcomeA, outcomeB and an Invalid option.
query_filter = self.build_filter_for_conditional_markets(
include_conditional_markets
)
query_filter["outcomes_contains"] = [INVALID_OUTCOME]
markets_field = self.seer_subgraph.Query.markets(where=query_filter)
fields = self._get_fields_for_markets(markets_field)
markets = self.do_query(fields=fields, pydantic_model=SeerMarket)
two_category_markets = self.filter_bicategorical_markets(markets)
return two_category_markets

def get_binary_markets(
self, include_conditional_markets: bool = True
) -> list[SeerMarket]:
two_category_markets = self.get_bicategorical_markets(
include_conditional_markets=include_conditional_markets
)
# Now we additionally filter markets based on YES/NO being the only outcomes.
binary_markets = self.filter_binary_markets(two_category_markets)
return binary_markets

gabrielfior marked this conversation as resolved.
Show resolved Hide resolved
def get_market_by_id(self, market_id: HexBytes) -> SeerMarket:
markets_field = self.seer_subgraph.Query.market(id=market_id.hex().lower())
fields = self._get_fields_for_markets(markets_field)
markets = self.do_query(fields=fields, pydantic_model=SeerMarket)
if len(markets) != 1:
raise ValueError(
f"Fetched wrong number of markets. Expected 1 but got {len(markets)}"
)
Comment on lines +110 to +113
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue

Include market_id in the error message for clarity

Including the market_id in the error message provides context and aids in debugging when the expected market is not retrieved.

Apply this diff to improve the error message:

if len(markets) != 1:
    raise ValueError(
-       f"Fetched wrong number of markets. Expected 1 but got {len(markets)}"
+       f"Fetched wrong number of markets for ID {market_id.hex()}. Expected 1 but got {len(markets)}"
    )
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
if len(markets) != 1:
raise ValueError(
f"Fetched wrong number of markets. Expected 1 but got {len(markets)}"
)
if len(markets) != 1:
raise ValueError(
f"Fetched wrong number of markets for ID {market_id.hex()}. Expected 1 but got {len(markets)}"
)

return markets[0]

def _get_fields_for_pools(self, pools_field: FieldPath) -> list[FieldPath]:
fields = [
pools_field.id,
pools_field.liquidity,
pools_field.token0.id,
pools_field.token0.name,
pools_field.token0.symbol,
pools_field.token1.id,
pools_field.token1.name,
pools_field.token1.symbol,
]
return fields

def get_pools_for_market(self, market: SeerMarket) -> list[SeerPool]:
# We iterate through the wrapped tokens and put them in a where clause so that we hit the subgraph endpoint just once.
wheres = []
for wrapped_token in market.wrapped_tokens:
wheres.extend(
[
{"token0": wrapped_token.hex().lower()},
{"token1": wrapped_token.hex().lower()},
]
)
pools_field = self.swapr_algebra_subgraph.Query.pools(where={"or": wheres})
fields = self._get_fields_for_pools(pools_field)
pools = self.do_query(fields=fields, pydantic_model=SeerPool)
return pools
Comment on lines +129 to +142
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue

Add input validation for market parameter.

The method should validate that the input market is not None before proceeding.

def get_pools_for_market(self, market: SeerMarket) -> list[SeerPool]:
+    if market is None:
+        raise ValueError("Market cannot be None")
+
    # We iterate through the wrapped tokens and put them in a where clause so that we hit the subgraph endpoint just once.
    wheres = []
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
def get_pools_for_market(self, market: SeerMarket) -> list[SeerPool]:
# We iterate through the wrapped tokens and put them in a where clause so that we hit the subgraph endpoint just once.
wheres = []
for wrapped_token in market.wrapped_tokens:
wheres.extend(
[
{"token0": wrapped_token.hex().lower()},
{"token1": wrapped_token.hex().lower()},
]
)
pools_field = self.swapr_algebra_subgraph.Query.pools(where={"or": wheres})
fields = self._get_fields_for_pools(pools_field)
pools = self.do_query(fields=fields, pydantic_model=SeerPool)
return pools
def get_pools_for_market(self, market: SeerMarket) -> list[SeerPool]:
if market is None:
raise ValueError("Market cannot be None")
# We iterate through the wrapped tokens and put them in a where clause so that we hit the subgraph endpoint just once.
wheres = []
for wrapped_token in market.wrapped_tokens:
wheres.extend(
[
{"token0": wrapped_token.hex().lower()},
{"token1": wrapped_token.hex().lower()},
]
)
pools_field = self.swapr_algebra_subgraph.Query.pools(where={"or": wheres})
fields = self._get_fields_for_pools(pools_field)
pools = self.do_query(fields=fields, pydantic_model=SeerPool)
return pools

Loading
Loading