Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: simulate and query multiple models #71

Merged
merged 2 commits into from
Jan 27, 2025
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -3,4 +3,6 @@ state_modifier: >
You are Talk2BioModels agent.
If the user asks for the uploaded model,
then pass the use_uploaded_model argument
as True.
as True. If the user asks for simulation,
then suggest a value for the `simulation_name`
argument.
2 changes: 1 addition & 1 deletion aiagents4pharma/talk2biomodels/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,4 +4,4 @@
from . import models
from . import tools
from . import agents
from . import states
from . import states
Original file line number Diff line number Diff line change
Expand Up @@ -20,5 +20,5 @@ class Talk2Biomodels(AgentState):
# the operator for the sbml_file_path field.
# https://langchain-ai.github.io/langgraph/troubleshooting/errors/INVALID_CONCURRENT_GRAPH_UPDATE/
sbml_file_path: Annotated[list, operator.add]
dic_simulated_data: dict
dic_simulated_data: Annotated[list[dict], operator.add]
dmccloskey marked this conversation as resolved.
Show resolved Hide resolved
llm_model: str
53 changes: 35 additions & 18 deletions aiagents4pharma/talk2biomodels/tests/test_langgraph.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
'''
Test cases
Test cases for Talk2Biomodels.
'''

import pandas as pd
from langchain_core.messages import HumanMessage, ToolMessage
from ..agents.t2b_agent import get_app

Expand Down Expand Up @@ -56,22 +57,33 @@ def test_ask_question_tool():

##########################################
# Test ask_question tool when simulation
# results are not available
# results are not available, i.e., the
# simulation has not been run. In this
# case, the tool should return an error.
##########################################
# Update state
app.update_state(config, {"llm_model": "gpt-4o-mini"})
# Define the prompt
prompt = "Call the ask_question tool to answer the "
prompt += "question: What is the concentration of CRP "
prompt += "in serum at 1000 hours?"

# Test the tool get_modelinfo
response = app.invoke(
{"messages": [HumanMessage(content=prompt)]},
config=config
)
assistant_msg = response["messages"][-1].content
# Check if the assistant message is a string
assert isinstance(assistant_msg, str)
prompt += "in serum at 1000 hours? The simulation name "
prompt += "is `simulation_name`."
# Invoke the tool
app.invoke(
{"messages": [HumanMessage(content=prompt)]},
config=config
)
# Get the messages from the current state
# and reverse the order
current_state = app.get_state(config)
reversed_messages = current_state.values["messages"][::-1]
# Loop through the reversed messages and
# check every ToolMessage encountered.
for msg in reversed_messages:
# If the message is a ToolMessage, assert
# that its status is "error"
if isinstance(msg, ToolMessage):
assert msg.status == "error"

def test_simulate_model_tool():
'''
Expand Down Expand Up @@ -138,9 +150,9 @@ def test_simulate_model_tool():
reversed_messages = current_state.values["messages"][::-1]
# Loop through the reversed messages
# until a ToolMessage is found.
expected_artifact = ['CRP[serum]', 'CRPExtracellular']
expected_artifact += ['CRP Suppression (%)', 'CRP (% of baseline)']
expected_artifact += ['CRP[liver]']
expected_header = ['Time', 'CRP[serum]', 'CRPExtracellular']
expected_header += ['CRP Suppression (%)', 'CRP (% of baseline)']
expected_header += ['CRP[liver]']
predicted_artifact = []
for msg in reversed_messages:
if isinstance(msg, ToolMessage):
Expand All @@ -150,9 +162,14 @@ def test_simulate_model_tool():
if msg.name == "custom_plotter":
predicted_artifact = msg.artifact
break
# Check if the two artifacts are equal
# assert expected_artifact in predicted_artifact
assert set(expected_artifact).issubset(set(predicted_artifact))
# Convert the artifact into a pandas dataframe
# for easy comparison
df = pd.DataFrame(predicted_artifact)
# Extract the headers from the dataframe
predicted_header = df.columns.tolist()
# Check that every expected header is present in the predicted header
# assert expected_header in predicted_artifact
assert set(expected_header).issubset(set(predicted_header))
##########################################
# Test custom_plotter tool when the
# simulation results are available but
Expand Down
23 changes: 16 additions & 7 deletions aiagents4pharma/talk2biomodels/tools/ask_question.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,8 @@ class AskQuestionInput(BaseModel):
Input schema for the AskQuestion tool.
"""
question: str = Field(description="question about the simulation results")
simulation_name: str = Field(description="""Name assigned to the simulation
when the tool simulate_model was invoked.""")
state: Annotated[dict, InjectedState]

# Note: It's important that every field has type hints.
Expand All @@ -39,25 +41,32 @@ class AskQuestionTool(BaseTool):

def _run(self,
question: str,
simulation_name: str,
state: Annotated[dict, InjectedState]) -> str:
"""
Run the tool.

Args:
question (str): The question to ask about the simulation results.
state (dict): The state of the graph.
run_manager (Optional[CallbackManagerForToolRun]): The CallbackManagerForToolRun object.
simulation_name (str): The name assigned to the simulation.

Returns:
str: The answer to the question.
"""
logger.log(logging.INFO,
"Calling ask_question tool %s", question)
# Check if the simulation results are available
if 'dic_simulated_data' not in state:
return "Please run the simulation first before \
asking a question about the simulation results."
df = pd.DataFrame.from_dict(state['dic_simulated_data'])
"Calling ask_question tool %s, %s", question, simulation_name)
dic_simulated_data = {}
for data in state["dic_simulated_data"]:
for key in data:
if key not in dic_simulated_data:
dic_simulated_data[key] = []
dic_simulated_data[key] += [data[key]]
# print (dic_simulated_data)
df_simulated_data = pd.DataFrame.from_dict(dic_simulated_data)
df = pd.DataFrame(
df_simulated_data[df_simulated_data['name'] == simulation_name]['data'].iloc[0]
)
prompt_content = None
# if run_manager and 'prompt' in run_manager.metadata:
# prompt_content = run_manager.metadata['prompt']
Expand Down
34 changes: 20 additions & 14 deletions aiagents4pharma/talk2biomodels/tools/custom_plotter.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,14 +6,11 @@

import logging
from typing import Type, List, TypedDict, Annotated, Tuple, Union, Literal
from typing import Type, List, TypedDict, Annotated, Tuple, Union, Literal
from pydantic import BaseModel, Field
import pandas as pd
import pandas as pd
from langchain_openai import ChatOpenAI
from langchain_core.tools import BaseTool
from langgraph.prebuilt import InjectedState
from langgraph.prebuilt import InjectedState

# Initialize logger
logging.basicConfig(level=logging.INFO)
Expand All @@ -24,7 +21,7 @@ class CustomPlotterInput(BaseModel):
Input schema for the PlotImage tool.
"""
question: str = Field(description="Description of the plot")
state: Annotated[dict, InjectedState]
simulation_name: str = Field(description="Name assigned to the simulation")
state: Annotated[dict, InjectedState]

# Note: It's important that every field has type hints.
Expand All @@ -41,10 +38,10 @@ class CustomPlotterTool(BaseTool):
description: str = "A tool to make custom plots of the simulation results"
args_schema: Type[BaseModel] = CustomPlotterInput
response_format: str = "content_and_artifact"
response_format: str = "content_and_artifact"

def _run(self,
question: str,
simulation_name: str,
state: Annotated[dict, InjectedState]
) -> Tuple[str, Union[None, List[str]]]:
"""
Expand All @@ -53,17 +50,24 @@ def _run(self,
Args:
question (str): The question about the custom plot.
state (dict): The state of the graph.
question (str): The question about the custom plot.
state (dict): The state of the graph.

Returns:
str: The answer to the question
"""
logger.log(logging.INFO, "Calling custom_plotter tool %s", question)
# Check if the simulation results are available
# if 'dic_simulated_data' not in state:
# return "Please run the simulation first before plotting the figure.", None
df = pd.DataFrame.from_dict(state['dic_simulated_data'])
dic_simulated_data = {}
for data in state["dic_simulated_data"]:
for key in data:
if key not in dic_simulated_data:
dic_simulated_data[key] = []
dic_simulated_data[key] += [data[key]]
# Create a pandas dataframe from the dictionary
df = pd.DataFrame.from_dict(dic_simulated_data)
# Get the simulated data for the current tool call
df = pd.DataFrame(
df[df['name'] == simulation_name]['data'].iloc[0]
)
# df = pd.DataFrame.from_dict(state['dic_simulated_data'])
species_names = df.columns.tolist()
# Exclude the time column
species_names.remove('Time')
Expand All @@ -76,7 +80,8 @@ class CustomHeader(TypedDict):
A list of species based on user question.
"""
relevant_species: Union[None, List[Literal[*species_names]]] = Field(
description="List of species based on user question. If no relevant species are found, it will be None.")
description="""List of species based on user question.
If no relevant species are found, it will be None.""")
# Create an instance of the LLM model
llm = ChatOpenAI(model=state['llm_model'], temperature=0)
llm_with_structured_output = llm.with_structured_output(CustomHeader)
Expand All @@ -90,5 +95,6 @@ class CustomHeader(TypedDict):
logger.info("Extracted species: %s", extracted_species)
if len(extracted_species) == 0:
return "No species found in the simulation results that matches the user prompt.", None
content = f"Plotted custom figure with species: {', '.join(extracted_species)}"
return content, extracted_species
# Include the time column
extracted_species.insert(0, 'Time')
return f"Custom plot {simulation_name}", df[extracted_species].to_dict(orient='records')
12 changes: 6 additions & 6 deletions aiagents4pharma/talk2biomodels/tools/get_modelinfo.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,12 +25,12 @@ class RequestedModelInfo:
"""
Dataclass for storing the requested model information.
"""
species: bool = Field(description="Get species from the model.")
parameters: bool = Field(description="Get parameters from the model.")
compartments: bool = Field(description="Get compartments from the model.")
units: bool = Field(description="Get units from the model.")
description: bool = Field(description="Get description from the model.")
name: bool = Field(description="Get name from the model.")
species: bool = Field(description="Get species from the model.", default=False)
parameters: bool = Field(description="Get parameters from the model.", default=False)
compartments: bool = Field(description="Get compartments from the model.", default=False)
units: bool = Field(description="Get units from the model.", default=False)
description: bool = Field(description="Get description from the model.", default=False)
name: bool = Field(description="Get name from the model.", default=False)

class GetModelInfoInput(BaseModel):
"""
Expand Down
30 changes: 22 additions & 8 deletions aiagents4pharma/talk2biomodels/tools/simulate_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,9 @@ class ArgumentData:
recurring_data: RecurringData = Field(
description="species and time data on recurring basis",
default=None)
simulation_name: str = Field(
description="""An AI assigned `_` separated name of
the simulation based on human query""")

def add_rec_events(model_object, recurring_data):
"""
Expand All @@ -86,9 +89,12 @@ class SimulateModelInput(BaseModel):
"""
Input schema for the SimulateModel tool.
"""
sys_bio_model: ModelData = Field(description="model data", default=None)
arg_data: ArgumentData = Field(description="time, species, and recurring data",
default=None)
sys_bio_model: ModelData = Field(description="model data",
default=None)
arg_data: ArgumentData = Field(description=
"""time, species, and recurring data
dmccloskey marked this conversation as resolved.
Show resolved Hide resolved
as well as the simulation name""",
default=None)
tool_call_id: Annotated[str, InjectedToolCallId]
state: Annotated[dict, InjectedState]

Expand Down Expand Up @@ -153,24 +159,32 @@ def _run(self,
interval=interval
)

dic_simulated_data = {
'name': arg_data.simulation_name,
'source': sys_bio_model.biomodel_id if sys_bio_model.biomodel_id else 'upload',
dmccloskey marked this conversation as resolved.
Show resolved Hide resolved
'tool_call_id': tool_call_id,
'data': df.to_dict()
}

# Prepare the dictionary of updated state for the model
dic_updated_state_for_model = {}
for key, value in {
"model_id": [sys_bio_model.biomodel_id],
"sbml_file_path": [sbml_file_path],
}.items():
"model_id": [sys_bio_model.biomodel_id],
"sbml_file_path": [sbml_file_path],
"dic_simulated_data": [dic_simulated_data],
}.items():
if value:
dic_updated_state_for_model[key] = value

# Return the updated state of the tool
return Command(
update=dic_updated_state_for_model|{
# update the state keys
"dic_simulated_data": df.to_dict(),
# "dic_simulated_data": df.to_dict(),
# update the message history
"messages": [
ToolMessage(
content="Simulation results are ready.",
content=f"Simulation results of {arg_data.simulation_name}",
tool_call_id=tool_call_id
)
],
Expand Down
Loading
Loading