From dcc28c12b3ed01519449926c13f50eb91017cf5c Mon Sep 17 00:00:00 2001 From: virattt Date: Sun, 1 Dec 2024 12:26:03 -0500 Subject: [PATCH] Add show-decisions parameter to print individual agent decisions --- src/agents.py | 60 +++++++++++++++++++++++++++++++++++++-------------- 1 file changed, 44 insertions(+), 16 deletions(-) diff --git a/src/agents.py b/src/agents.py index 696149a8..14bb209c 100644 --- a/src/agents.py +++ b/src/agents.py @@ -37,6 +37,8 @@ def market_data_agent(state: AgentState): ##### 2. Quantitative Agent ##### def quant_agent(state: AgentState): """Analyzes technical indicators and generates trading signals.""" + show_decisions = state["messages"][0].additional_kwargs["show_decisions"] + data = state["data"] prices = data["prices"] prices_df = prices_to_df(prices) @@ -107,14 +109,15 @@ def quant_agent(state: AgentState): confidence = max(bullish_signals, bearish_signals) / total_signals # Create the quant agent's message - message_content = f""" - Trading Signal: {overall_signal} - Confidence (0-1, higher is better): {confidence:.2f} - """ + message_content = f"""Quant Trading Signal: {overall_signal} \nConfidence (0-1, higher is better): {confidence:.2f}""" message = HumanMessage( content=message_content.strip(), name="quant_agent", ) + + # Print the decision if the flag is set + if show_decisions: + show_agent_decision(message.content, "Quant Agent") return { "messages": state["messages"] + [message], @@ -124,6 +127,7 @@ def quant_agent(state: AgentState): ##### 3. Risk Management Agent ##### def risk_management_agent(state: AgentState): """Evaluates portfolio risk and sets position limits""" + show_decisions = state["messages"][0].additional_kwargs["show_decisions"] portfolio = state["messages"][0].additional_kwargs["portfolio"] last_message = state["messages"][-1] @@ -135,8 +139,8 @@ def risk_management_agent(state: AgentState): Your job is to take a look at the trading analysis and evaluate portfolio exposure and recommend position sizing. Provide the following in your output (not as a JSON): - - max_position_size: , - - risk_score: """ + Max Position Size: , + Risk Score: """ ), MessagesPlaceholder(variable_name="messages"), ( @@ -158,15 +162,21 @@ def risk_management_agent(state: AgentState): chain = risk_prompt | llm result = chain.invoke(state).content message = HumanMessage( - content=f"Here is the risk management recommendation: {result}", + content=f"{result}", name="risk_management", ) + + # Print the decision if the flag is set + if show_decisions: + show_agent_decision(message.content, "Risk Management Agent") + return {"messages": state["messages"] + [message]} ##### 4. Portfolio Management Agent ##### def portfolio_management_agent(state: AgentState): """Makes final trading decisions and generates orders""" + show_decisions = state["messages"][0].additional_kwargs["show_decisions"] portfolio = state["messages"][0].additional_kwargs["portfolio"] last_message = state["messages"][-1] @@ -208,26 +218,41 @@ def portfolio_management_agent(state: AgentState): chain = portfolio_prompt | llm result = chain.invoke(state).content - return {"messages": [HumanMessage(content=result, name="portfolio_management")]} + message = HumanMessage( + content=f"{result}", + name="portfolio_management", + ) + + # Print the decision if the flag is set + if show_decisions: + show_agent_decision(message.content, "Portfolio Management Agent") + + return {"messages": state["messages"] + [message]} + +def show_agent_decision(output, agent_name): + print(f"\n{'=' * 5} {agent_name.center(28)} {'=' * 5}") + print(output) + print("=" * 40) ##### Run the Hedge Fund ##### -def run_hedge_fund(ticker: str, start_date: str, end_date: str, portfolio: dict): +def run_hedge_fund(ticker: str, start_date: str, end_date: str, portfolio: dict, show_decisions: bool = False): final_state = app.invoke( { "messages": [ HumanMessage( content="Make a trading decision based on the provided data.", additional_kwargs={ - "ticker": ticker, - "start_date": start_date, - "end_date": end_date, - "portfolio": portfolio + "portfolio": portfolio, + "show_decisions": show_decisions, }, ) ], - "data": {"ticker": ticker, "start_date": start_date, "end_date": end_date} + "data": { + "ticker": ticker, + "start_date": start_date, + "end_date": end_date + }, }, - config={"configurable": {"thread_id": 42}}, ) return final_state["messages"][-1].content @@ -255,6 +280,7 @@ def run_hedge_fund(ticker: str, start_date: str, end_date: str, portfolio: dict) parser.add_argument('--ticker', type=str, required=True, help='Stock ticker symbol') parser.add_argument('--start-date', type=str, required=True, help='Start date (YYYY-MM-DD)') parser.add_argument('--end-date', type=str, required=True, help='End date (YYYY-MM-DD)') + parser.add_argument('--show-decisions', action='store_true', help='Show decisions from each agent') args = parser.parse_args() @@ -275,6 +301,8 @@ def run_hedge_fund(ticker: str, start_date: str, end_date: str, portfolio: dict) ticker=args.ticker, start_date=args.start_date, end_date=args.end_date, - portfolio=portfolio + portfolio=portfolio, + show_decisions=args.show_decisions ) + print("\nFinal Result:") print(result) \ No newline at end of file