Skip to content

Commit

Permalink
GetWalletBalance -> GetBalance (#65)
Browse files Browse the repository at this point in the history
* GetWalletBalance -> GetBalance
  • Loading branch information
evangriffiths authored Apr 10, 2024
1 parent 83ff1be commit 0348176
Show file tree
Hide file tree
Showing 3 changed files with 26 additions and 30 deletions.
34 changes: 6 additions & 28 deletions prediction_market_agent/agents/microchain_agent/functions.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
import pprint
import typing as t
from decimal import Decimal
from typing import cast
Expand All @@ -16,10 +15,10 @@
from prediction_market_agent_tooling.markets.omen.omen_subgraph_handler import (
OmenSubgraphHandler,
)
from prediction_market_agent_tooling.tools.balances import get_balances

from prediction_market_agent.agents.microchain_agent.utils import (
MicroMarket,
get_balance,
get_binary_market_from_question,
get_binary_markets,
get_market_token_balance,
Expand All @@ -28,7 +27,6 @@
)
from prediction_market_agent.utils import APIKeys

balance = 50
outcomeTokens = {}
outcomeTokens["Will Joe Biden get reelected in 2024?"] = {"yes": 0, "no": 0}
outcomeTokens["Will Bitcoin hit 100k in 2024?"] = {"yes": 0, "no": 0}
Expand Down Expand Up @@ -100,21 +98,6 @@ def __call__(self, a: str) -> float:
return 0.0


class GetBalance(MarketFunction):
@property
def description(self) -> str:
return "Use this function to get your own balance in $"

@property
def example_args(self) -> list[str]:
return []

def __call__(self) -> float:
print(f"Your balance is: {balance} and ")
pprint.pprint(outcomeTokens)
return balance


class BuyTokens(MarketFunction):
def __init__(self, market_type: MarketType, outcome: str):
self.outcome = outcome
Expand Down Expand Up @@ -222,25 +205,21 @@ def example_args(self) -> list[str]:
]

def __call__(self, summary: str) -> str:
# print(summary)
# pprint.pprint(outcomeTokens)
return summary


class GetWalletBalance(MarketFunction):
class GetBalance(MarketFunction):
@property
def description(self) -> str:
return "Use this function to fetch your balance, given in xDAI units."
currency = self.market_type.market_class.currency
return f"Use this function to fetch your balance, given in {currency} units."

@property
def example_args(self) -> list[str]:
return []

def __call__(self) -> Decimal:
# We focus solely on xDAI balance for now to avoid the agent having to wrap/unwrap xDAI.
user_address_checksummed = APIKeys().bet_from_address
balance = get_balances(user_address_checksummed)
return balance.xdai
return get_balance(market_type=self.market_type).amount


class GetUserPositions(MarketFunction):
Expand All @@ -263,7 +242,7 @@ def __call__(self, user_address: str) -> list[OmenUserPosition]:
MISC_FUNCTIONS = [
Sum,
Product,
SummarizeLearning,
# SummarizeLearning,
]

# Functions that interact with the prediction markets
Expand All @@ -275,6 +254,5 @@ def __call__(self, user_address: str) -> list[OmenUserPosition]:
BuyNo,
SellYes,
SellNo,
GetWalletBalance,
GetUserPositions,
]
18 changes: 18 additions & 0 deletions prediction_market_agent/agents/microchain_agent/utils.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,12 @@
from decimal import Decimal

from eth_typing import ChecksumAddress
from prediction_market_agent_tooling.markets.agent_market import (
AgentMarket,
FilterBy,
SortBy,
)
from prediction_market_agent_tooling.markets.data_models import BetAmount
from prediction_market_agent_tooling.markets.markets import MarketType
from prediction_market_agent_tooling.markets.omen.data_models import (
OMEN_FALSE_OUTCOME,
Expand All @@ -15,10 +18,13 @@
from prediction_market_agent_tooling.markets.omen.omen_subgraph_handler import (
OmenSubgraphHandler,
)
from prediction_market_agent_tooling.tools.balances import get_balances
from prediction_market_agent_tooling.tools.hexbytes_custom import HexBytes
from pydantic import BaseModel
from web3.types import Wei

from prediction_market_agent.utils import APIKeys


class MicroMarket(BaseModel):
question: str
Expand Down Expand Up @@ -50,6 +56,18 @@ def get_binary_markets(market_type: MarketType) -> list[AgentMarket]:
return markets


def get_balance(market_type: MarketType) -> BetAmount:
currency = market_type.market_class.currency
if market_type == MarketType.OMEN:
# We focus solely on xDAI balance for now to avoid the agent having to wrap/unwrap xDAI.
return BetAmount(
amount=Decimal(get_balances(APIKeys().bet_from_address).xdai),
currency=currency,
)
else:
raise ValueError(f"Market type '{market_type}' not supported")


def get_binary_market_from_question(
market: str, market_type: MarketType
) -> AgentMarket:
Expand Down
4 changes: 2 additions & 2 deletions tests/agents/test_microchain.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,9 @@
MISC_FUNCTIONS,
BuyNo,
BuyYes,
GetBalance,
GetMarkets,
GetUserPositions,
GetWalletBalance,
)
from prediction_market_agent.agents.microchain_agent.utils import (
get_binary_markets,
Expand Down Expand Up @@ -54,7 +54,7 @@ def test_buy_no(market_type: MarketType) -> None:

@pytest.mark.parametrize("market_type", [MarketType.OMEN])
def test_replicator_has_balance_gt_0(market_type: MarketType) -> None:
balance = GetWalletBalance(market_type=market_type)()
balance = GetBalance(market_type=market_type)()
assert balance > 0


Expand Down

0 comments on commit 0348176

Please sign in to comment.