Skip to content

Commit

Permalink
NFT Game Status in agents (#640)
Browse files Browse the repository at this point in the history
  • Loading branch information
kongzii authored Jan 17, 2025
1 parent 9799074 commit 43e22f4
Show file tree
Hide file tree
Showing 10 changed files with 244 additions and 77 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
from microchain import Function
from prediction_market_agent_tooling.tools.utils import utcnow


class TodayDate(Function):
    """Microchain function that tells the agent the current UTC date, time and weekday."""

    @property
    def description(self) -> str:
        return "Use this function to get the current date."

    @property
    def example_args(self) -> list[str]:
        # No arguments are needed to look up the current date.
        return []

    def __call__(self) -> str:
        # Use the shared UTC helper so every agent agrees on what "now" means.
        current = utcnow()
        timestamp = current.strftime('%Y-%m-%d %H:%M:%S')
        weekday = current.strftime('%A')
        return f"Today is {timestamp}. The day is {weekday}."


# Functions made available to every microchain agent, regardless of its specialization.
COMMON_FUNCTIONS: list[type[Function]] = [
    TodayDate,
]
144 changes: 87 additions & 57 deletions prediction_market_agent/agents/microchain_agent/deploy.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,13 +38,20 @@


class DeployableMicrochainAgentAbstract(DeployableAgent, metaclass=abc.ABCMeta):
# Setup per-agent class.
model = SupportedModel.gpt_4o
max_iterations: int | None = 50
import_actions_from_memory = 0
sleep_between_iterations = 0
identifier: AgentIdentifier
functions_config: FunctionsConfig

# Setup during the 'load' method.
long_term_memory: LongTermMemoryTableHandler
prompt_handler: PromptTableHandler
agent: Agent
goal_manager: GoalManager | None

@classmethod
def get_description(cls) -> str:
return f"Microchain-based {cls.__name__}."
Expand All @@ -54,12 +61,42 @@ def get_description(cls) -> str:
def get_initial_system_prompt(cls) -> str:
pass

def build_long_term_memory(self) -> LongTermMemoryTableHandler:
return LongTermMemoryTableHandler.from_agent_identifier(self.identifier)

def build_prompt_handler(self) -> PromptTableHandler:
return PromptTableHandler.from_agent_identifier(self.identifier)

def build_agent(self, market_type: MarketType) -> Agent:
unformatted_system_prompt = get_unformatted_system_prompt(
unformatted_prompt=self.get_initial_system_prompt(),
prompt_table_handler=self.prompt_handler,
)

return build_agent(
market_type=market_type,
model=self.model,
unformatted_system_prompt=unformatted_system_prompt,
allow_stop=True,
long_term_memory=self.long_term_memory,
import_actions_from_memory=self.import_actions_from_memory,
keys=APIKeys(),
functions_config=self.functions_config,
enable_langfuse=self.enable_langfuse,
)

    def build_goal_manager(
        self,
        agent: Agent,
    ) -> GoalManager | None:
        # Hook for subclasses: return a GoalManager to enable goal-driven runs,
        # or None (the default) to run without goal evaluation.
        return None

    def load(self) -> None:
        """Initialize the per-run attributes declared on the class.

        Order matters: `prompt_handler` is read by `build_agent`, and the built
        agent is passed to `build_goal_manager`.
        """
        self.long_term_memory = self.build_long_term_memory()
        self.prompt_handler = self.build_prompt_handler()
        self.agent = self.build_agent(market_type=MarketType.OMEN)
        self.goal_manager = self.build_goal_manager(agent=self.agent)

def run(
self,
market_type: MarketType,
Expand All @@ -71,106 +108,99 @@ def run(
self.run_general_agent(market_type=market_type)

@observe()
def run_general_agent(
self,
market_type: MarketType,
) -> None:
self.langfuse_update_current_trace(tags=[GENERAL_AGENT_TAG, self.identifier])

long_term_memory = LongTermMemoryTableHandler.from_agent_identifier(
self.identifier
)
prompt_handler = PromptTableHandler.from_agent_identifier(self.identifier)
unformatted_system_prompt = get_unformatted_system_prompt(
unformatted_prompt=self.get_initial_system_prompt(),
prompt_table_handler=prompt_handler,
)
def run_general_agent(self, market_type: MarketType) -> None:
if market_type != MarketType.OMEN:
raise ValueError(f"Only {MarketType.OMEN} market type is supported.")

agent: Agent = build_agent(
market_type=market_type,
model=self.model,
unformatted_system_prompt=unformatted_system_prompt,
allow_stop=True,
long_term_memory=long_term_memory,
import_actions_from_memory=self.import_actions_from_memory,
keys=APIKeys(),
functions_config=self.functions_config,
enable_langfuse=self.enable_langfuse,
)
self.langfuse_update_current_trace(tags=[GENERAL_AGENT_TAG, self.identifier])

goal_manager = self.build_goal_manager(agent=agent)
goal = goal_manager.get_goal() if goal_manager else None
goal = self.goal_manager.get_goal() if self.goal_manager else None
if goal:
agent.prompt = goal.to_prompt()
self.agent.prompt = goal.to_prompt()

# Save formatted system prompt
initial_formatted_system_prompt = agent.system_prompt
initial_formatted_system_prompt = self.agent.system_prompt

iteration = 0
while not agent.do_stop and (

while not self.agent.do_stop and (
self.max_iterations is None or iteration < self.max_iterations
):
starting_history_length = len(agent.history)
self.before_iteration_callback()

starting_history_length = len(self.agent.history)
try:
# After the first iteration, resume=True to not re-initialize the agent.
agent.run(iterations=1, resume=iteration > 0)
self.agent.run(iterations=1, resume=iteration > 0)
except Exception as e:
logger.error(f"Error while running microchain agent: {e}")
raise e
finally:
# Save the agent's history to the long-term memory after every iteration to keep users updated.
save_agent_history(
agent=agent,
long_term_memory=long_term_memory,
initial_system_prompt=initial_formatted_system_prompt,
# Because the agent is running in a while cycle, always save into database only what's new, to not duplicate entries.
save_last_n=len(agent.history) - starting_history_length,
self.save_agent_history(
initial_formatted_system_prompt=initial_formatted_system_prompt,
save_last_n=len(self.agent.history) - starting_history_length,
)
if agent.system_prompt != initial_formatted_system_prompt:
prompt_handler.save_prompt(get_editable_prompt_from_agent(agent))
if self.agent.system_prompt != initial_formatted_system_prompt:
self.prompt_handler.save_prompt(
get_editable_prompt_from_agent(self.agent)
)

iteration += 1
logger.info(f"{self.__class__.__name__} iteration {iteration} completed.")

self.after_iteration_callback()

if self.sleep_between_iterations:
logger.info(
f"{self.__class__.__name__} sleeping for {self.sleep_between_iterations} seconds."
)
time.sleep(self.sleep_between_iterations)

if goal_manager:
if self.goal_manager:
self.handle_goal_evaluation(
agent,
check_not_none(goal),
goal_manager,
long_term_memory,
initial_formatted_system_prompt,
check_not_none(goal), initial_formatted_system_prompt
)

    def save_agent_history(
        self, initial_formatted_system_prompt: str, save_last_n: int
    ) -> None:
        """Persist the last `save_last_n` messages of the agent's history to long-term memory.

        Thin wrapper around the module-level `save_agent_history` helper (which this
        method intentionally shadows by name), supplying the instance's agent and memory.
        """
        save_agent_history(
            agent=self.agent,
            long_term_memory=self.long_term_memory,
            initial_system_prompt=initial_formatted_system_prompt,
            # Because the agent is running in a while cycle, always save into database only what's new, to not duplicate entries.
            save_last_n=save_last_n,
        )

    def before_iteration_callback(self) -> None:
        """Hook invoked before each agent iteration; subclasses may override."""
        pass

    def after_iteration_callback(self) -> None:
        """Hook invoked after each agent iteration; subclasses may override."""
        pass

def handle_goal_evaluation(
self,
agent: Agent,
goal: Goal,
goal_manager: GoalManager,
long_term_memory: LongTermMemoryTableHandler,
initial_formatted_system_prompt: str,
) -> None:
goal_evaluation = goal_manager.evaluate_goal_progress(
assert self.goal_manager is not None, "Goal manager must be set."
goal_evaluation = self.goal_manager.evaluate_goal_progress(
goal=goal,
chat_history=ChatHistory.from_list_of_dicts(agent.history),
chat_history=ChatHistory.from_list_of_dicts(self.agent.history),
)
goal_manager.save_evaluated_goal(
self.goal_manager.save_evaluated_goal(
goal=goal,
evaluation=goal_evaluation,
)
agent.history.append(
self.agent.history.append(
ChatMessage(
role="user",
content=f"# Goal evaluation\n{goal_evaluation}",
).model_dump()
)
save_agent_history(
agent=agent,
long_term_memory=long_term_memory,
initial_system_prompt=initial_formatted_system_prompt,
self.save_agent_history(
initial_formatted_system_prompt=initial_formatted_system_prompt,
# Save only the new (last) message, which is the goal evaluation.
save_last_n=1,
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,9 @@

class LongTermMemoryBasedFunction(Function):
def __init__(
self, long_term_memory: LongTermMemoryTableHandler, model: str
self,
long_term_memory: LongTermMemoryTableHandler,
model: str = "gpt-4o", # Use model that works well with these functions, provide other one only if necessary.
) -> None:
self.long_term_memory = long_term_memory
self.model = model
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,9 @@
from prediction_market_agent.agents.microchain_agent.code_functions import (
CODE_FUNCTIONS,
)
from prediction_market_agent.agents.microchain_agent.common_functions import (
COMMON_FUNCTIONS,
)
from prediction_market_agent.agents.microchain_agent.jobs_functions import JOB_FUNCTIONS
from prediction_market_agent.agents.microchain_agent.learning_functions import (
LEARNING_FUNCTIONS,
Expand Down Expand Up @@ -134,6 +137,9 @@ def build_agent_functions(
if allow_stop:
functions.append(Stop())

if functions_config.common_functions:
functions.extend(f() for f in COMMON_FUNCTIONS)

if functions_config.include_agent_functions:
functions.extend([f(agent=agent) for f in AGENT_FUNCTIONS])

Expand Down Expand Up @@ -171,9 +177,7 @@ def build_agent_functions(
functions.extend(f() for f in BALANCE_FUNCTIONS)

if long_term_memory:
functions.extend(
f(long_term_memory=long_term_memory, model=model) for f in MEMORY_FUNCTIONS
)
functions.extend(f(long_term_memory=long_term_memory) for f in MEMORY_FUNCTIONS)

return functions

Expand All @@ -188,6 +192,7 @@ def build_agent(
api_base: str = "https://api.openai.com/v1",
long_term_memory: LongTermMemoryTableHandler | None = None,
import_actions_from_memory: int = 0,
max_tokens: int = 8196,
allow_stop: bool = True,
bootstrap: str | None = None,
raise_on_error: bool = True,
Expand All @@ -200,6 +205,7 @@ def build_agent(
api_base=api_base,
temperature=0.7,
enable_langfuse=enable_langfuse,
max_tokens=max_tokens,
)
if model.is_openai
else (
Expand All @@ -210,6 +216,7 @@ def build_agent(
),
api_key=keys.replicate_api_key.get_secret_value(),
enable_langfuse=enable_langfuse,
max_tokens=max_tokens,
)
if model.is_replicate
else should_not_happen()
Expand Down Expand Up @@ -239,7 +246,9 @@ def step_end_callback(agent: Agent, step_output: StepOutput) -> None:
).search(limit=import_actions_from_memory)
agent.history.extend(
m.metadata_dict
for m in latest_saved_memories
for m in latest_saved_memories[
::-1
] # Revert the list to have the oldest messages first, as they were in the history.
if check_not_none(m.metadata_dict)["role"]
!= "system" # Do not include system message as that one is automatically in the beginning of the history.
)
Expand All @@ -261,7 +270,7 @@ def step_end_callback(agent: Agent, step_output: StepOutput) -> None:
engine_help=agent.engine.help
)
if bootstrap:
agent.bootstrap = [bootstrap]
agent.bootstrap.append(bootstrap)
return agent


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,9 @@
from streamlit_extras.stylable_container import stylable_container

from prediction_market_agent.agents.identifiers import AgentIdentifier
from prediction_market_agent.agents.microchain_agent.agent_functions import (
UpdateMySystemPrompt,
)
from prediction_market_agent.agents.microchain_agent.nft_functions import BalanceOfNFT
from prediction_market_agent.agents.microchain_agent.nft_treasury_game.constants_nft_treasury_game import (
NFT_TOKEN_FACTORY,
Expand All @@ -33,6 +36,7 @@
)
from prediction_market_agent.agents.microchain_agent.nft_treasury_game.messages_functions import (
BroadcastPublicMessageToHumans,
GameRoundEnd,
ReceiveMessage,
SendPaidMessageToAnotherAgent,
Wait,
Expand Down Expand Up @@ -145,6 +149,10 @@ def customized_chat_message(
icon = "😴"
case Wait.__name__:
icon = "⏳"
case UpdateMySystemPrompt.__name__:
icon = "📝"
case GameRoundEnd.__name__:
icon = "🏁"
case ReceiveMessage.__name__:
icon = "👤"
case BroadcastPublicMessageToHumans.__name__:
Expand Down Expand Up @@ -179,6 +187,8 @@ def customized_chat_message(
BroadcastPublicMessageToHumans.__name__,
SendPaidMessageToAnotherAgent.__name__,
Wait.__name__,
GameRoundEnd.__name__,
UpdateMySystemPrompt.__name__,
):
st.markdown(parsed_function_output_body)

Expand Down
Loading

0 comments on commit 43e22f4

Please sign in to comment.