Skip to content

Commit

Permalink
Add show-decisions parameter to print individual agent decisions
Browse files Browse the repository at this point in the history
  • Loading branch information
virattt committed Dec 1, 2024
1 parent 9d2fd81 commit dcc28c1
Showing 1 changed file with 44 additions and 16 deletions.
60 changes: 44 additions & 16 deletions src/agents.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,8 @@ def market_data_agent(state: AgentState):
##### 2. Quantitative Agent #####
def quant_agent(state: AgentState):
"""Analyzes technical indicators and generates trading signals."""
show_decisions = state["messages"][0].additional_kwargs["show_decisions"]

data = state["data"]
prices = data["prices"]
prices_df = prices_to_df(prices)
Expand Down Expand Up @@ -107,14 +109,15 @@ def quant_agent(state: AgentState):
confidence = max(bullish_signals, bearish_signals) / total_signals

# Create the quant agent's message
message_content = f"""
Trading Signal: {overall_signal}
Confidence (0-1, higher is better): {confidence:.2f}
"""
message_content = f"""Quant Trading Signal: {overall_signal} \nConfidence (0-1, higher is better): {confidence:.2f}"""
message = HumanMessage(
content=message_content.strip(),
name="quant_agent",
)

# Print the decision if the flag is set
if show_decisions:
show_agent_decision(message.content, "Quant Agent")

return {
"messages": state["messages"] + [message],
Expand All @@ -124,6 +127,7 @@ def quant_agent(state: AgentState):
##### 3. Risk Management Agent #####
def risk_management_agent(state: AgentState):
"""Evaluates portfolio risk and sets position limits"""
show_decisions = state["messages"][0].additional_kwargs["show_decisions"]
portfolio = state["messages"][0].additional_kwargs["portfolio"]
last_message = state["messages"][-1]

Expand All @@ -135,8 +139,8 @@ def risk_management_agent(state: AgentState):
Your job is to take a look at the trading analysis and
evaluate portfolio exposure and recommend position sizing.
Provide the following in your output (not as a JSON):
- max_position_size: <float greater than 0>,
- risk_score: <integer between 1 and 10>"""
Max Position Size: <float greater than 0>,
Risk Score: <integer between 1 and 10>"""
),
MessagesPlaceholder(variable_name="messages"),
(
Expand All @@ -158,15 +162,21 @@ def risk_management_agent(state: AgentState):
chain = risk_prompt | llm
result = chain.invoke(state).content
message = HumanMessage(
content=f"Here is the risk management recommendation: {result}",
content=f"{result}",
name="risk_management",
)

# Print the decision if the flag is set
if show_decisions:
show_agent_decision(message.content, "Risk Management Agent")

return {"messages": state["messages"] + [message]}


##### 4. Portfolio Management Agent #####
def portfolio_management_agent(state: AgentState):
"""Makes final trading decisions and generates orders"""
show_decisions = state["messages"][0].additional_kwargs["show_decisions"]
portfolio = state["messages"][0].additional_kwargs["portfolio"]
last_message = state["messages"][-1]

Expand Down Expand Up @@ -208,26 +218,41 @@ def portfolio_management_agent(state: AgentState):

chain = portfolio_prompt | llm
result = chain.invoke(state).content
return {"messages": [HumanMessage(content=result, name="portfolio_management")]}
message = HumanMessage(
content=f"{result}",
name="portfolio_management",
)

# Print the decision if the flag is set
if show_decisions:
show_agent_decision(message.content, "Portfolio Management Agent")

return {"messages": state["messages"] + [message]}

def show_agent_decision(output, agent_name):
print(f"\n{'=' * 5} {agent_name.center(28)} {'=' * 5}")
print(output)
print("=" * 40)

##### Run the Hedge Fund #####
def run_hedge_fund(ticker: str, start_date: str, end_date: str, portfolio: dict):
def run_hedge_fund(ticker: str, start_date: str, end_date: str, portfolio: dict, show_decisions: bool = False):
final_state = app.invoke(
{
"messages": [
HumanMessage(
content="Make a trading decision based on the provided data.",
additional_kwargs={
"ticker": ticker,
"start_date": start_date,
"end_date": end_date,
"portfolio": portfolio
"portfolio": portfolio,
"show_decisions": show_decisions,
},
)
],
"data": {"ticker": ticker, "start_date": start_date, "end_date": end_date}
"data": {
"ticker": ticker,
"start_date": start_date,
"end_date": end_date
},
},
config={"configurable": {"thread_id": 42}},
)
return final_state["messages"][-1].content

Expand Down Expand Up @@ -255,6 +280,7 @@ def run_hedge_fund(ticker: str, start_date: str, end_date: str, portfolio: dict)
parser.add_argument('--ticker', type=str, required=True, help='Stock ticker symbol')
parser.add_argument('--start-date', type=str, required=True, help='Start date (YYYY-MM-DD)')
parser.add_argument('--end-date', type=str, required=True, help='End date (YYYY-MM-DD)')
parser.add_argument('--show-decisions', action='store_true', help='Show decisions from each agent')

args = parser.parse_args()

Expand All @@ -275,6 +301,8 @@ def run_hedge_fund(ticker: str, start_date: str, end_date: str, portfolio: dict)
ticker=args.ticker,
start_date=args.start_date,
end_date=args.end_date,
portfolio=portfolio
portfolio=portfolio,
show_decisions=args.show_decisions
)
print("\nFinal Result:")
print(result)

0 comments on commit dcc28c1

Please sign in to comment.