From 65a8f78e625c9989a213f608d412677afcf26d67 Mon Sep 17 00:00:00 2001 From: Thibault Le Sellier de Chezelles Date: Wed, 11 Sep 2024 10:07:14 -0400 Subject: [PATCH 01/37] updating dependencies for github workflows --- requirements.txt | 3 +++ 1 file changed, 3 insertions(+) diff --git a/requirements.txt b/requirements.txt index eae71c15..86890452 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,3 +1,6 @@ +black[jupyter]==24.2.0 +blacken-docs +pre-commit pytest==7.3.2 pytest-xdist pytest-playwright From ab0f4ec2f06e5f0becbc8379a98611ffb8f8aab6 Mon Sep 17 00:00:00 2001 From: Thibault Le Sellier de Chezelles Date: Wed, 18 Sep 2024 12:12:06 -0400 Subject: [PATCH 02/37] openrouter tracker poc --- .../launch_command.py => launch_command.py | 11 +- src/agentlab/agents/generic_agent/__init__.py | 2 + .../agents/generic_agent/agent_configs.py | 8 +- .../agents/generic_agent/generic_agent.py | 3 +- src/agentlab/llm/chat_api.py | 3 +- src/agentlab/llm/tracking.py | 104 ++++++++++++++++++ 6 files changed, 120 insertions(+), 11 deletions(-) rename src/agentlab/experiments/launch_command.py => launch_command.py (76%) create mode 100644 src/agentlab/llm/tracking.py diff --git a/src/agentlab/experiments/launch_command.py b/launch_command.py similarity index 76% rename from src/agentlab/experiments/launch_command.py rename to launch_command.py index 01b48a7f..198bbe07 100644 --- a/src/agentlab/experiments/launch_command.py +++ b/launch_command.py @@ -7,21 +7,22 @@ import logging -from agentlab.agents.generic_agent import RANDOM_SEARCH_AGENT, AGENT_4o, AGENT_4o_MINI +from agentlab.agents.generic_agent import AGENT_CUSTOM, RANDOM_SEARCH_AGENT, AGENT_4o, AGENT_4o_MINI from agentlab.analyze.inspect_results import get_most_recent_folder from agentlab.experiments import study_generators from agentlab.experiments.exp_utils import RESULTS_DIR -from agentlab.experiments.launch_exp import make_study_dir, run_experiments, relaunch_study +from agentlab.experiments.launch_exp import make_study_dir, relaunch_study, run_experiments logging.getLogger().setLevel(logging.INFO) # choose your agent or provide a new agent -agent_args = AGENT_4o_MINI +agent_args = [AGENT_CUSTOM, AGENT_4o_MINI] # agent = AGENT_4o ## select the benchmark to run on benchmark = "miniwob" +benchmark = "miniwob_tiny_test" # benchmark = "workarena.l1" # benchmark = "workarena.l2" # benchmark = "workarena.l3" @@ -37,8 +38,8 @@ ## alternatively, relaunch an existing study -study_dir = get_most_recent_folder(RESULTS_DIR, contains=None) -exp_args_list, study_dir = relaunch_study(study_dir, relaunch_mode="incomplete_or_error") +# study_dir = get_most_recent_folder(RESULTS_DIR, contains=None) +# exp_args_list, study_dir = relaunch_study(study_dir, relaunch_mode="incomplete_or_error") ## Number of parallel jobs diff --git a/src/agentlab/agents/generic_agent/__init__.py b/src/agentlab/agents/generic_agent/__init__.py index d9839c4d..fec74910 100644 --- a/src/agentlab/agents/generic_agent/__init__.py +++ b/src/agentlab/agents/generic_agent/__init__.py @@ -2,6 +2,7 @@ AGENT_3_5, AGENT_8B, AGENT_70B, + AGENT_CUSTOM, RANDOM_SEARCH_AGENT, AGENT_4o, AGENT_4o_MINI, @@ -16,4 +17,5 @@ "AGENT_70B", "AGENT_8B", "RANDOM_SEARCH_AGENT", + "AGENT_CUSTOM", ] diff --git a/src/agentlab/agents/generic_agent/agent_configs.py b/src/agentlab/agents/generic_agent/agent_configs.py index a53046b2..da0a6799 100644 --- a/src/agentlab/agents/generic_agent/agent_configs.py +++ b/src/agentlab/agents/generic_agent/agent_configs.py @@ -1,9 +1,9 @@ -from .generic_agent_prompt 
import GenericPromptFlags from agentlab.agents import dynamic_prompting as dp -from .generic_agent import GenericAgentArgs -from agentlab.llm.llm_configs import CHAT_MODEL_ARGS_DICT from agentlab.experiments import args +from agentlab.llm.llm_configs import CHAT_MODEL_ARGS_DICT +from .generic_agent import GenericAgentArgs +from .generic_agent_prompt import GenericPromptFlags FLAGS_CUSTOM = GenericPromptFlags( obs=dp.ObsFlags( @@ -45,7 +45,7 @@ AGENT_CUSTOM = GenericAgentArgs( - chat_model_args=CHAT_MODEL_ARGS_DICT["openai/gpt-3.5-turbo-1106"], + chat_model_args=CHAT_MODEL_ARGS_DICT["openrouter/meta-llama/llama-3.1-8b-instruct"], flags=FLAGS_CUSTOM, ) diff --git a/src/agentlab/agents/generic_agent/generic_agent.py b/src/agentlab/agents/generic_agent/generic_agent.py index a53f1aeb..eb3b0d87 100644 --- a/src/agentlab/agents/generic_agent/generic_agent.py +++ b/src/agentlab/agents/generic_agent/generic_agent.py @@ -10,6 +10,7 @@ from agentlab.agents.utils import openai_monitored_agent from agentlab.llm.chat_api import BaseModelArgs from agentlab.llm.llm_utils import RetryError, retry_raise +from agentlab.llm.tracking import get_action_decorator from .generic_agent_prompt import GenericPromptFlags, MainPrompt @@ -65,7 +66,7 @@ def __init__( def obs_preprocessor(self, obs: dict) -> dict: return self._obs_preprocessor(obs) - @openai_monitored_agent + @get_action_decorator def get_action(self, obs): self.obs_history.append(obs) diff --git a/src/agentlab/llm/chat_api.py b/src/agentlab/llm/chat_api.py index 2f425da7..1246081a 100644 --- a/src/agentlab/llm/chat_api.py +++ b/src/agentlab/llm/chat_api.py @@ -12,6 +12,7 @@ HuggingFaceAPIChatModel, HuggingFaceURLChatModel, ) +from agentlab.llm.tracking import OpenRouterChatModel if TYPE_CHECKING: from langchain_core.language_models.chat_models import BaseChatModel @@ -86,7 +87,7 @@ class OpenRouterModelArgs(BaseModelArgs): model.""" def make_model(self): - return ChatOpenRouter( + return OpenRouterChatModel( model_name=self.model_name, temperature=self.temperature, max_tokens=self.max_new_tokens, diff --git a/src/agentlab/llm/tracking.py b/src/agentlab/llm/tracking.py new file mode 100644 index 00000000..13b33455 --- /dev/null +++ b/src/agentlab/llm/tracking.py @@ -0,0 +1,104 @@ +import os +from contextlib import contextmanager +from typing import Any, List, Optional + +import requests +from langchain.schema import AIMessage, BaseMessage +from openai import OpenAI + +from agentlab.llm.langchain_utils import _convert_messages_to_dict + + +class LLMTracker: + def __init__(self): + self.input_tokens = 0 + self.output_tokens = 0 + self.cost = 0 + + def __call__(self, input_tokens: int, output_tokens: int, cost: float): + self.input_tokens += input_tokens + self.output_tokens += output_tokens + self.cost += cost + + @property + def stats(self): + return { + "input_tokens": self.input_tokens, + "output_tokens": self.output_tokens, + "cost": self.cost, + } + + +@contextmanager +def set_tracker(tracker: LLMTracker): + global current_tracker + previous_tracker = globals().get("current_tracker", None) + current_tracker = tracker + yield + current_tracker = previous_tracker + + +def get_action_decorator(get_action): + def wrapper(self, obs): + tracker = LLMTracker() + with set_tracker(tracker): + action, agent_info = get_action(self, obs) + agent_info.get("stats").update(tracker.stats) + return action, agent_info + + return wrapper + + +class OpenRouterChatModel: + def __init__( + self, + model_name, + openrouter_api_key=None, + 
openrouter_api_base="https://openrouter.ai/api/v1", + temperature=0.5, + max_tokens=100, + ): + self.model_name = model_name + self.openrouter_api_key = openrouter_api_key + self.openrouter_api_base = openrouter_api_base + self.temperature = temperature + self.max_tokens = max_tokens + + openrouter_api_key = openrouter_api_key or os.getenv("OPENROUTER_API_KEY") + + # query api to get model metadata + url = "https://openrouter.ai/api/v1/models" + headers = {"Authorization": f"Bearer {openrouter_api_key}"} + response = requests.get(url, headers=headers) + + if response.status_code != 200: + raise ValueError("Failed to get model metadata") + + model_metadata = response.json() + pricings = {model["id"]: model["pricing"] for model in model_metadata["data"]} + + self.input_cost = float(pricings[model_name]["prompt"]) + self.output_cost = float(pricings[model_name]["completion"]) + + self.client = OpenAI( + base_url=openrouter_api_base, + api_key=openrouter_api_key, + ) + + def __call__(self, messages: List[BaseMessage]) -> str: + messages_formated = _convert_messages_to_dict(messages) + completion = self.client.chat.completions.create( + model=self.model_name, messages=messages_formated + ) + input_tokens = completion.usage.prompt_tokens + output_tokens = completion.usage.completion_tokens + cost = input_tokens * self.input_cost + output_tokens * self.output_cost + + global current_tracker + if "current_tracker" in globals() and isinstance(current_tracker, LLMTracker): + current_tracker(input_tokens, output_tokens, cost) + + return AIMessage(content=completion.choices[0].message.content) + + def invoke(self, messages: List[BaseMessage]) -> AIMessage: + return self(messages) From ca50598a2d1efd3719f059745dc07c6dc815805f Mon Sep 17 00:00:00 2001 From: Thibault Le Sellier de Chezelles Date: Wed, 18 Sep 2024 14:47:14 -0400 Subject: [PATCH 03/37] adding openai pricing request --- src/agentlab/llm/chat_api.py | 4 +- src/agentlab/llm/tracking.py | 159 +++++++++++++++++++++++++++++------ 2 files changed, 134 insertions(+), 29 deletions(-) diff --git a/src/agentlab/llm/chat_api.py b/src/agentlab/llm/chat_api.py index 1246081a..035838aa 100644 --- a/src/agentlab/llm/chat_api.py +++ b/src/agentlab/llm/chat_api.py @@ -12,7 +12,7 @@ HuggingFaceAPIChatModel, HuggingFaceURLChatModel, ) -from agentlab.llm.tracking import OpenRouterChatModel +from agentlab.llm.tracking import OpenAIChatModel, OpenRouterChatModel if TYPE_CHECKING: from langchain_core.language_models.chat_models import BaseChatModel @@ -100,7 +100,7 @@ class OpenAIModelArgs(BaseModelArgs): model.""" def make_model(self): - return ChatOpenAI( + return OpenAIChatModel( model_name=self.model_name, temperature=self.temperature, max_tokens=self.max_new_tokens, diff --git a/src/agentlab/llm/tracking.py b/src/agentlab/llm/tracking.py index 13b33455..fcf92ae6 100644 --- a/src/agentlab/llm/tracking.py +++ b/src/agentlab/llm/tracking.py @@ -1,10 +1,12 @@ +import ast import os +from abc import ABC, abstractmethod from contextlib import contextmanager from typing import Any, List, Optional import requests from langchain.schema import AIMessage, BaseMessage -from openai import OpenAI +from openai import AzureOpenAI, OpenAI from agentlab.llm.langchain_utils import _convert_messages_to_dict @@ -49,46 +51,75 @@ def wrapper(self, obs): return wrapper -class OpenRouterChatModel: - def __init__( - self, - model_name, - openrouter_api_key=None, - openrouter_api_base="https://openrouter.ai/api/v1", - temperature=0.5, - max_tokens=100, - ): - self.model_name = 
model_name - self.openrouter_api_key = openrouter_api_key - self.openrouter_api_base = openrouter_api_base - self.temperature = temperature - self.max_tokens = max_tokens - - openrouter_api_key = openrouter_api_key or os.getenv("OPENROUTER_API_KEY") - +def get_pricing(api: str = "openrouter", api_key: str = None): + if api == "openrouter": + assert api_key, "OpenRouter API key is required" # query api to get model metadata url = "https://openrouter.ai/api/v1/models" - headers = {"Authorization": f"Bearer {openrouter_api_key}"} + headers = {"Authorization": f"Bearer {api_key}"} response = requests.get(url, headers=headers) if response.status_code != 200: raise ValueError("Failed to get model metadata") model_metadata = response.json() - pricings = {model["id"]: model["pricing"] for model in model_metadata["data"]} + return { + model["id"]: {k: float(v) for k, v in model["pricing"].items()} + for model in model_metadata["data"] + } + elif api == "openai": + url = "https://raw.githubusercontent.com/langchain-ai/langchain/master/libs/community/langchain_community/callbacks/openai_info.py" + response = requests.get(url) + + if response.status_code == 200: + content = response.text + tree = ast.parse(content) + cost_dict = None + for node in tree.body: + if isinstance(node, ast.Assign) and isinstance(node.targets[0], ast.Name): + if node.targets[0].id == "MODEL_COST_PER_1K_TOKENS": + cost_dict = ast.literal_eval(node.value) + break + if cost_dict: + cost_dict = {k: v / 1000 for k, v in cost_dict.items()} + res = {} + for k in cost_dict: + if k.endswith("-completion"): + continue + prompt_key = k + completion_key = k + "-completion" + if completion_key in cost_dict: + res[k] = { + "prompt": cost_dict[prompt_key], + "completion": cost_dict[completion_key], + } + return res + else: + raise ValueError("Cost dictionary not found.") + else: + raise ValueError(f"Failed to retrieve the file. 
Status code: {response.status_code}") + + +class ChatModel(ABC): + + @abstractmethod + def __init__(self, model_name, api_key=None, temperature=0.5, max_tokens=100): + self.model_name = model_name + self.temperature = temperature + self.max_tokens = max_tokens - self.input_cost = float(pricings[model_name]["prompt"]) - self.output_cost = float(pricings[model_name]["completion"]) + self.client = OpenAI() - self.client = OpenAI( - base_url=openrouter_api_base, - api_key=openrouter_api_key, - ) + self.input_cost = 0.0 + self.output_cost = 0.0 def __call__(self, messages: List[BaseMessage]) -> str: messages_formated = _convert_messages_to_dict(messages) completion = self.client.chat.completions.create( - model=self.model_name, messages=messages_formated + model=self.model_name, + messages=messages_formated, + temperature=self.temperature, + max_tokens=self.max_tokens, ) input_tokens = completion.usage.prompt_tokens output_tokens = completion.usage.completion_tokens @@ -102,3 +133,77 @@ def __call__(self, messages: List[BaseMessage]) -> str: def invoke(self, messages: List[BaseMessage]) -> AIMessage: return self(messages) + + +class OpenRouterChatModel(ChatModel): + def __init__( + self, + model_name, + api_key=None, + temperature=0.5, + max_tokens=100, + ): + self.model_name = model_name + self.temperature = temperature + self.max_tokens = max_tokens + + api_key = api_key or os.getenv("OPENROUTER_API_KEY") + + pricings = get_pricing(api="openrouter", api_key=api_key) + + self.input_cost = pricings[model_name]["prompt"] + self.output_cost = pricings[model_name]["completion"] + + self.client = OpenAI( + base_url="https://openrouter.ai/api/v1", + api_key=api_key, + ) + + +class OpenAIChatModel(ChatModel): + def __init__( + self, + model_name, + api_key=None, + temperature=0.5, + max_tokens=100, + ): + self.model_name = model_name + self.temperature = temperature + self.max_tokens = max_tokens + + api_key = api_key or os.getenv("OPENAI_API_KEY") + + pricings = get_pricing(api="openai") + + self.input_cost = float(pricings[model_name]["prompt"]) + self.output_cost = float(pricings[model_name]["completion"]) + + self.client = OpenAI( + api_key=api_key, + ) + + +class AzureChatModel(ChatModel): + def __init__( + self, + model_name, + api_key=None, + endpoint=None, + temperature=0.5, + max_tokens=100, + ): + self.model_name = model_name + self.temperature = temperature + self.max_tokens = max_tokens + + api_key = api_key or os.getenv("OPENAI_API_KEY") + + pricings = get_pricing(api="openai") + + self.input_cost = float(pricings[model_name]["prompt"]) + self.output_cost = float(pricings[model_name]["completion"]) + + self.client = AzureOpenAI( + api_key=api_key, azure_endpoint=endpoint, api_version="2024-02-01" + ) From 0cde22b6b76f2c82ebfffe24bfb17127f0d334b3 Mon Sep 17 00:00:00 2001 From: Thibault Le Sellier de Chezelles Date: Thu, 19 Sep 2024 10:56:53 -0400 Subject: [PATCH 04/37] switching back to langchain community for openai pricing --- src/agentlab/llm/tracking.py | 45 ++++++++++++------------------------ 1 file changed, 15 insertions(+), 30 deletions(-) diff --git a/src/agentlab/llm/tracking.py b/src/agentlab/llm/tracking.py index fcf92ae6..32777904 100644 --- a/src/agentlab/llm/tracking.py +++ b/src/agentlab/llm/tracking.py @@ -6,6 +6,7 @@ import requests from langchain.schema import AIMessage, BaseMessage +from langchain_community.callbacks.openai_info import MODEL_COST_PER_1K_TOKENS from openai import AzureOpenAI, OpenAI from agentlab.llm.langchain_utils import _convert_messages_to_dict @@ 
-68,36 +69,20 @@ def get_pricing(api: str = "openrouter", api_key: str = None): for model in model_metadata["data"] } elif api == "openai": - url = "https://raw.githubusercontent.com/langchain-ai/langchain/master/libs/community/langchain_community/callbacks/openai_info.py" - response = requests.get(url) - - if response.status_code == 200: - content = response.text - tree = ast.parse(content) - cost_dict = None - for node in tree.body: - if isinstance(node, ast.Assign) and isinstance(node.targets[0], ast.Name): - if node.targets[0].id == "MODEL_COST_PER_1K_TOKENS": - cost_dict = ast.literal_eval(node.value) - break - if cost_dict: - cost_dict = {k: v / 1000 for k, v in cost_dict.items()} - res = {} - for k in cost_dict: - if k.endswith("-completion"): - continue - prompt_key = k - completion_key = k + "-completion" - if completion_key in cost_dict: - res[k] = { - "prompt": cost_dict[prompt_key], - "completion": cost_dict[completion_key], - } - return res - else: - raise ValueError("Cost dictionary not found.") - else: - raise ValueError(f"Failed to retrieve the file. Status code: {response.status_code}") + cost_dict = MODEL_COST_PER_1K_TOKENS + cost_dict = {k: v / 1000 for k, v in cost_dict.items()} + res = {} + for k in cost_dict: + if k.endswith("-completion"): + continue + prompt_key = k + completion_key = k + "-completion" + if completion_key in cost_dict: + res[k] = { + "prompt": cost_dict[prompt_key], + "completion": cost_dict[completion_key], + } + return res class ChatModel(ABC): From 060e1e74b90bef6fcd18a6dc32a4d1307abf5430 Mon Sep 17 00:00:00 2001 From: Thibault Le Sellier de Chezelles Date: Thu, 19 Sep 2024 11:02:53 -0400 Subject: [PATCH 05/37] renaming launch_command.py to main.py --- launch_command.py => main.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) rename launch_command.py => main.py (93%) diff --git a/launch_command.py b/main.py similarity index 93% rename from launch_command.py rename to main.py index 198bbe07..bfb696c2 100644 --- a/launch_command.py +++ b/main.py @@ -16,7 +16,7 @@ logging.getLogger().setLevel(logging.INFO) # choose your agent or provide a new agent -agent_args = [AGENT_CUSTOM, AGENT_4o_MINI] +agent_args = [AGENT_4o_MINI] # agent = AGENT_4o @@ -48,4 +48,5 @@ # run the experiments -run_experiments(n_jobs, exp_args_list, study_dir) +if __name__ == "__main__": + run_experiments(n_jobs, exp_args_list, study_dir) From 600cfcacf17c1703b7855d93d3d95cc786f28121 Mon Sep 17 00:00:00 2001 From: Thibault Le Sellier de Chezelles Date: Thu, 19 Sep 2024 11:09:50 -0400 Subject: [PATCH 06/37] typo --- main.py | 9 +++------ 1 file changed, 3 insertions(+), 6 deletions(-) diff --git a/main.py b/main.py index 5bdef316..e6234637 100644 --- a/main.py +++ b/main.py @@ -7,13 +7,11 @@ import logging -from agentlab.agents.generic_agent import (AGENT_CUSTOM, RANDOM_SEARCH_AGENT, - AGENT_4o, AGENT_4o_MINI) +from agentlab.agents.generic_agent import AGENT_CUSTOM, RANDOM_SEARCH_AGENT, AGENT_4o, AGENT_4o_MINI from agentlab.analyze.inspect_results import get_most_recent_folder from agentlab.experiments import study_generators from agentlab.experiments.exp_utils import RESULTS_DIR -from agentlab.experiments.launch_exp import (make_study_dir, relaunch_study, - run_experiments) +from agentlab.experiments.launch_exp import make_study_dir, relaunch_study, run_experiments logging.getLogger().setLevel(logging.INFO) @@ -50,5 +48,4 @@ # run the experiments if __name__ == "__main__": - run_experiments(n_jobs, exp_args_list, study_dir)if __name__ == "__main__": - 
run_experiments(n_jobs, exp_args_list, study_dir) \ No newline at end of file + run_experiments(n_jobs, exp_args_list, study_dir) From 708bde5719c738e1b62a3489f8d47fb37aa0fa69 Mon Sep 17 00:00:00 2001 From: Thibault Le Sellier de Chezelles Date: Fri, 20 Sep 2024 10:30:04 -0400 Subject: [PATCH 07/37] tracking is thread safe and mostly tested --- src/agentlab/llm/tracking.py | 52 +++++++++++----- tests/llm/test_tracking.py | 112 +++++++++++++++++++++++++++++++++++ 2 files changed, 149 insertions(+), 15 deletions(-) create mode 100644 tests/llm/test_tracking.py diff --git a/src/agentlab/llm/tracking.py b/src/agentlab/llm/tracking.py index 32777904..d964c156 100644 --- a/src/agentlab/llm/tracking.py +++ b/src/agentlab/llm/tracking.py @@ -1,5 +1,6 @@ import ast import os +import threading from abc import ABC, abstractmethod from contextlib import contextmanager from typing import Any, List, Optional @@ -11,12 +12,14 @@ from agentlab.llm.langchain_utils import _convert_messages_to_dict +TRACKER = threading.local() + class LLMTracker: def __init__(self): self.input_tokens = 0 self.output_tokens = 0 - self.cost = 0 + self.cost = 0.0 def __call__(self, input_tokens: int, output_tokens: int, cost: float): self.input_tokens += input_tokens @@ -31,20 +34,33 @@ def stats(self): "cost": self.cost, } + def add_tracker(self, tracker: "LLMTracker"): + self(tracker.input_tokens, tracker.output_tokens, tracker.cost) + + def __repr__(self): + return f"LLMTracker(input_tokens={self.input_tokens}, output_tokens={self.output_tokens}, cost={self.cost})" + @contextmanager -def set_tracker(tracker: LLMTracker): - global current_tracker - previous_tracker = globals().get("current_tracker", None) - current_tracker = tracker - yield - current_tracker = previous_tracker +def set_tracker(): + global TRACKER + if not hasattr(TRACKER, "instance"): + TRACKER.instance = None + previous_tracker = TRACKER.instance # type: LLMTracker + TRACKER.instance = LLMTracker() + try: + yield TRACKER.instance + finally: + # If there was a previous tracker, add the current one to it + if isinstance(previous_tracker, LLMTracker): + previous_tracker.add_tracker(TRACKER.instance) + # Restore the previous tracker + TRACKER.instance = previous_tracker def get_action_decorator(get_action): def wrapper(self, obs): - tracker = LLMTracker() - with set_tracker(tracker): + with set_tracker() as tracker: action, agent_info = get_action(self, obs) agent_info.get("stats").update(tracker.stats) return action, agent_info @@ -110,9 +126,8 @@ def __call__(self, messages: List[BaseMessage]) -> str: output_tokens = completion.usage.completion_tokens cost = input_tokens * self.input_cost + output_tokens * self.output_cost - global current_tracker - if "current_tracker" in globals() and isinstance(current_tracker, LLMTracker): - current_tracker(input_tokens, output_tokens, cost) + if isinstance(TRACKER.instance, LLMTracker): + TRACKER.instance(input_tokens, output_tokens, cost) return AIMessage(content=completion.choices[0].message.content) @@ -174,7 +189,7 @@ def __init__( self, model_name, api_key=None, - endpoint=None, + deployment_name=None, temperature=0.5, max_tokens=100, ): @@ -182,7 +197,11 @@ def __init__( self.temperature = temperature self.max_tokens = max_tokens - api_key = api_key or os.getenv("OPENAI_API_KEY") + api_key = api_key or os.getenv("AZURE_OPENAI_API_KEY") + + # AZURE_OPENAI_ENDPOINT has to be defined in the environment + endpoint = os.getenv("AZURE_OPENAI_ENDPOINT") + assert endpoint, "AZURE_OPENAI_ENDPOINT has to be defined in the 
environment" pricings = get_pricing(api="openai") @@ -190,5 +209,8 @@ def __init__( self.output_cost = float(pricings[model_name]["completion"]) self.client = AzureOpenAI( - api_key=api_key, azure_endpoint=endpoint, api_version="2024-02-01" + api_key=api_key, + azure_deployment=deployment_name, + azure_endpoint=endpoint, + api_version="2024-02-01", ) diff --git a/tests/llm/test_tracking.py b/tests/llm/test_tracking.py new file mode 100644 index 00000000..5f12cd68 --- /dev/null +++ b/tests/llm/test_tracking.py @@ -0,0 +1,112 @@ +import os +import time +from functools import partial + +import pytest + +import agentlab.llm.tracking as tracking + + +def test_get_action_decorator(): + action, agent_info = tracking.get_action_decorator(lambda x, y: call_llm())(None, None) + assert action == "action" + assert agent_info["stats"] == { + "input_tokens": 1, + "output_tokens": 1, + "cost": 1.0, + } + + +OPENROUTER_API_KEY_AVAILABLE = os.environ.get("OPENROUTER_API_KEY") is not None + +OPENROUTER_MODELS = ( + "openai/o1-mini-2024-09-12", + "openai/o1-preview-2024-09-12", + "openai/gpt-4o-2024-08-06", + "openai/gpt-4o-2024-05-13", + "anthropic/claude-3.5-sonnet:beta", + "anthropic/claude-3.5-sonnet", + "meta-llama/llama-3.1-405b-instruct", + "meta-llama/llama-3.1-70b-instruct", + "meta-llama/llama-3.1-8b-instruct", + "google/gemini-pro-1.5", + "qwen/qwen-2-vl-72b-instruct", +) + + +@pytest.mark.skipif(not OPENROUTER_API_KEY_AVAILABLE, reason="OpenRouter API key is not available") +def test_get_pricing_openrouter(): + pricing = tracking.get_pricing(api="openrouter", api_key=os.environ["OPENROUTER_API_KEY"]) + assert isinstance(pricing, dict) + assert all(isinstance(v, dict) for v in pricing.values()) + for model in OPENROUTER_MODELS: + assert model in pricing + assert isinstance(pricing[model], dict) + assert all(isinstance(v, float) for v in pricing[model].values()) + + +def test_get_pricing_openai(): + pricing = tracking.get_pricing(api="openai") + assert isinstance(pricing, dict) + assert all("prompt" in pricing[model] and "completion" in pricing[model] for model in pricing) + assert all(isinstance(pricing[model]["prompt"], float) for model in pricing) + assert all(isinstance(pricing[model]["completion"], float) for model in pricing) + + +def call_llm(): + if isinstance(tracking.TRACKER.instance, tracking.LLMTracker): + tracking.TRACKER.instance(1, 1, 1) + return "action", {"stats": {}} + + +def test_tracker(): + with tracking.set_tracker() as tracker: + _, _ = call_llm() + + assert tracker.stats["cost"] == 1 + + +def test_imbricate_trackers(): + with tracking.set_tracker() as tracker4: + with tracking.set_tracker() as tracker1: + _, _ = call_llm() + with tracking.set_tracker() as tracker3: + _, _ = call_llm() + _, _ = call_llm() + with tracking.set_tracker() as tracker1bis: + _, _ = call_llm() + + assert tracker1.stats["cost"] == 1 + assert tracker1bis.stats["cost"] == 1 + assert tracker3.stats["cost"] == 3 + assert tracker4.stats["cost"] == 4 + + +def test_threaded_trackers(): + """thread_2 occurs in the middle of thread_1, results should be separate.""" + import threading + + def thread_1(results=None): + with tracking.set_tracker() as tracker: + time.sleep(1) + _, _ = call_llm() + time.sleep(1) + results[0] = tracker.stats + + def thread_2(results=None): + time.sleep(1) + with tracking.set_tracker() as tracker: + _, _ = call_llm() + results[1] = tracker.stats + + results = [None] * 2 + threads = [ + threading.Thread(target=partial(thread_1, results=results)), + 
threading.Thread(target=partial(thread_2, results=results)), + ] + for thread in threads: + thread.start() + for thread in threads: + thread.join() + + assert all(result["cost"] == 1 for result in results) From 7c209459e067ac76e21543fc2f72bf70db900c39 Mon Sep 17 00:00:00 2001 From: Thibault Le Sellier de Chezelles Date: Fri, 20 Sep 2024 10:48:38 -0400 Subject: [PATCH 08/37] added pricy tests for ChatModels --- tests/llm/test_tracking.py | 68 ++++++++++++++++++++++++++++++++++++++ 1 file changed, 68 insertions(+) diff --git a/tests/llm/test_tracking.py b/tests/llm/test_tracking.py index 5f12cd68..a8141975 100644 --- a/tests/llm/test_tracking.py +++ b/tests/llm/test_tracking.py @@ -110,3 +110,71 @@ def thread_2(results=None): thread.join() assert all(result["cost"] == 1 for result in results) + + +OPENAI_API_KEY_AVAILABLE = os.environ.get("OPENAI_API_KEY") is not None + + +@pytest.mark.pricy +@pytest.mark.skipif(not OPENAI_API_KEY_AVAILABLE, reason="OpenAI API key is not available") +def test_openai_chat_model(): + chat_model = tracking.OpenAIChatModel("gpt-4o-mini") + assert chat_model.input_cost > 0 + assert chat_model.output_cost > 0 + + from langchain.schema import HumanMessage, SystemMessage + + messages = [ + SystemMessage(content="You are an helpful virtual assistant"), + HumanMessage(content="Give the third prime number"), + ] + with tracking.set_tracker() as tracker: + answer = chat_model.invoke(messages) + assert "5" in answer.content + assert tracker.stats["cost"] > 0 + + +AZURE_OPENAI_API_KEY_AVAILABLE = ( + os.environ.get("AZURE_OPENAI_API_KEY") is not None + and os.environ.get("AZURE_OPENAI_ENDPOINT") is not None +) + + +@pytest.mark.pricy +@pytest.mark.skipif( + not AZURE_OPENAI_API_KEY_AVAILABLE, reason="Azure OpenAI API key is not available" +) +def test_azure_chat_model(): + chat_model = tracking.AzureChatModel(model_name="gpt-35-turbo", deployment_name="gpt-35-turbo") + assert chat_model.input_cost > 0 + assert chat_model.output_cost > 0 + + from langchain.schema import HumanMessage, SystemMessage + + messages = [ + SystemMessage(content="You are an helpful virtual assistant"), + HumanMessage(content="Give the third prime number"), + ] + with tracking.set_tracker() as tracker: + answer = chat_model.invoke(messages) + assert "5" in answer.content + assert tracker.stats["cost"] > 0 + + +@pytest.mark.pricy +@pytest.mark.skipif(not OPENROUTER_API_KEY_AVAILABLE, reason="OpenRouter API key is not available") +def test_openrouter_chat_model(): + chat_model = tracking.OpenRouterChatModel("openai/gpt-4o-mini") + assert chat_model.input_cost > 0 + assert chat_model.output_cost > 0 + + from langchain.schema import HumanMessage, SystemMessage + + messages = [ + SystemMessage(content="You are an helpful virtual assistant"), + HumanMessage(content="Give the third prime number"), + ] + with tracking.set_tracker() as tracker: + answer = chat_model.invoke(messages) + assert "5" in answer.content + assert tracker.stats["cost"] > 0 From 18a45e0ba1361daafc0f83a8c670a3e95a27fad0 Mon Sep 17 00:00:00 2001 From: Thibault Le Sellier de Chezelles Date: Fri, 20 Sep 2024 10:51:44 -0400 Subject: [PATCH 09/37] separating get_pricing function --- src/agentlab/llm/tracking.py | 70 ++++++++++++++++++------------------ 1 file changed, 36 insertions(+), 34 deletions(-) diff --git a/src/agentlab/llm/tracking.py b/src/agentlab/llm/tracking.py index d964c156..54ed871d 100644 --- a/src/agentlab/llm/tracking.py +++ b/src/agentlab/llm/tracking.py @@ -68,37 +68,39 @@ def wrapper(self, obs): return wrapper -def 
get_pricing(api: str = "openrouter", api_key: str = None): - if api == "openrouter": - assert api_key, "OpenRouter API key is required" - # query api to get model metadata - url = "https://openrouter.ai/api/v1/models" - headers = {"Authorization": f"Bearer {api_key}"} - response = requests.get(url, headers=headers) - - if response.status_code != 200: - raise ValueError("Failed to get model metadata") - - model_metadata = response.json() - return { - model["id"]: {k: float(v) for k, v in model["pricing"].items()} - for model in model_metadata["data"] - } - elif api == "openai": - cost_dict = MODEL_COST_PER_1K_TOKENS - cost_dict = {k: v / 1000 for k, v in cost_dict.items()} - res = {} - for k in cost_dict: - if k.endswith("-completion"): - continue - prompt_key = k - completion_key = k + "-completion" - if completion_key in cost_dict: - res[k] = { - "prompt": cost_dict[prompt_key], - "completion": cost_dict[completion_key], - } - return res +def get_pricing_openrouter(): + api_key = os.getenv("OPENROUTER_API_KEY") + assert api_key, "OpenRouter API key is required" + # query api to get model metadata + url = "https://openrouter.ai/api/v1/models" + headers = {"Authorization": f"Bearer {api_key}"} + response = requests.get(url, headers=headers) + + if response.status_code != 200: + raise ValueError("Failed to get model metadata") + + model_metadata = response.json() + return { + model["id"]: {k: float(v) for k, v in model["pricing"].items()} + for model in model_metadata["data"] + } + + +def get_pricing_openai(): + cost_dict = MODEL_COST_PER_1K_TOKENS + cost_dict = {k: v / 1000 for k, v in cost_dict.items()} + res = {} + for k in cost_dict: + if k.endswith("-completion"): + continue + prompt_key = k + completion_key = k + "-completion" + if completion_key in cost_dict: + res[k] = { + "prompt": cost_dict[prompt_key], + "completion": cost_dict[completion_key], + } + return res class ChatModel(ABC): @@ -149,7 +151,7 @@ def __init__( api_key = api_key or os.getenv("OPENROUTER_API_KEY") - pricings = get_pricing(api="openrouter", api_key=api_key) + pricings = get_pricing_openrouter() self.input_cost = pricings[model_name]["prompt"] self.output_cost = pricings[model_name]["completion"] @@ -174,7 +176,7 @@ def __init__( api_key = api_key or os.getenv("OPENAI_API_KEY") - pricings = get_pricing(api="openai") + pricings = get_pricing_openai() self.input_cost = float(pricings[model_name]["prompt"]) self.output_cost = float(pricings[model_name]["completion"]) @@ -203,7 +205,7 @@ def __init__( endpoint = os.getenv("AZURE_OPENAI_ENDPOINT") assert endpoint, "AZURE_OPENAI_ENDPOINT has to be defined in the environment" - pricings = get_pricing(api="openai") + pricings = get_pricing_openai() self.input_cost = float(pricings[model_name]["prompt"]) self.output_cost = float(pricings[model_name]["completion"]) From 9d12cdfa5136095240a90c43812a9b0495bf461a Mon Sep 17 00:00:00 2001 From: Thibault Le Sellier de Chezelles Date: Fri, 20 Sep 2024 10:57:11 -0400 Subject: [PATCH 10/37] updating function names --- src/agentlab/llm/tracking.py | 10 ++++------ 1 file changed, 4 insertions(+), 6 deletions(-) diff --git a/src/agentlab/llm/tracking.py b/src/agentlab/llm/tracking.py index 54ed871d..8a530e73 100644 --- a/src/agentlab/llm/tracking.py +++ b/src/agentlab/llm/tracking.py @@ -3,7 +3,6 @@ import threading from abc import ABC, abstractmethod from contextlib import contextmanager -from typing import Any, List, Optional import requests from langchain.schema import AIMessage, BaseMessage @@ -116,11 +115,10 @@ def 
__init__(self, model_name, api_key=None, temperature=0.5, max_tokens=100): self.input_cost = 0.0 self.output_cost = 0.0 - def __call__(self, messages: List[BaseMessage]) -> str: - messages_formated = _convert_messages_to_dict(messages) + def __call__(self, messages: list[dict]) -> dict: completion = self.client.chat.completions.create( model=self.model_name, - messages=messages_formated, + messages=messages, temperature=self.temperature, max_tokens=self.max_tokens, ) @@ -131,9 +129,9 @@ def __call__(self, messages: List[BaseMessage]) -> str: if isinstance(TRACKER.instance, LLMTracker): TRACKER.instance(input_tokens, output_tokens, cost) - return AIMessage(content=completion.choices[0].message.content) + return dict(role="assistant", content=completion.choices[0].message.content) - def invoke(self, messages: List[BaseMessage]) -> AIMessage: + def invoke(self, messages: list[dict]) -> dict: return self(messages) From 8e5e5f939030df6ba8147e823a6cda44d0e7ad0e Mon Sep 17 00:00:00 2001 From: Thibault Le Sellier de Chezelles Date: Fri, 20 Sep 2024 11:01:20 -0400 Subject: [PATCH 11/37] updating function names --- tests/llm/test_tracking.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/llm/test_tracking.py b/tests/llm/test_tracking.py index a8141975..4672cacd 100644 --- a/tests/llm/test_tracking.py +++ b/tests/llm/test_tracking.py @@ -36,7 +36,7 @@ def test_get_action_decorator(): @pytest.mark.skipif(not OPENROUTER_API_KEY_AVAILABLE, reason="OpenRouter API key is not available") def test_get_pricing_openrouter(): - pricing = tracking.get_pricing(api="openrouter", api_key=os.environ["OPENROUTER_API_KEY"]) + pricing = tracking.get_pricing_openrouter() assert isinstance(pricing, dict) assert all(isinstance(v, dict) for v in pricing.values()) for model in OPENROUTER_MODELS: @@ -46,7 +46,7 @@ def test_get_pricing_openrouter(): def test_get_pricing_openai(): - pricing = tracking.get_pricing(api="openai") + pricing = tracking.get_pricing_openai() assert isinstance(pricing, dict) assert all("prompt" in pricing[model] and "completion" in pricing[model] for model in pricing) assert all(isinstance(pricing[model]["prompt"], float) for model in pricing) From 17e8ff8dbcf2217b7871e09d89f5d03621ec9240 Mon Sep 17 00:00:00 2001 From: Thibault Le Sellier de Chezelles Date: Fri, 20 Sep 2024 11:21:35 -0400 Subject: [PATCH 12/37] ciao retry_parallel --- src/agentlab/llm/llm_utils.py | 63 ----------------------------------- tests/llm/test_llm_utils.py | 34 ------------------- 2 files changed, 97 deletions(-) diff --git a/src/agentlab/llm/llm_utils.py b/src/agentlab/llm/llm_utils.py index 1a8d8b70..5c16ef0d 100644 --- a/src/agentlab/llm/llm_utils.py +++ b/src/agentlab/llm/llm_utils.py @@ -177,69 +177,6 @@ def retry_raise( raise RetryError(f"Could not parse a valid value after {n_retry} retries.") -def retry_parallel(chat: "BaseChatModel", messages, n_retry, parser): - """Retry querying the chat models with the response from the parser until it returns a valid value. - - It will stop after `n_retry`. It assuemes that chat will generate n_parallel answers for each message. - The best answer is selected according to the score returned by the parser. If no answer is valid, the - it will retry with the best answer so far and append to the chat the retry message. If there is a - single parallel generation, it behaves like retry. - - This function is, in principle, more robust than retry. 
The speed and cost overhead is minimal with - the prompt is large and the length of the generated message is small. - - Args: - chat (BaseChatModel): a langchain BaseChatModel taking a list of messages and - returning a list of answers. - messages (list): the list of messages so far. - n_retry (int): the maximum number of sequential retries. - parser (function): a function taking a message and returning a tuple - with the following fields: - value : the parsed value, - valid : a boolean indicating if the value is valid, - retry_message : a message to send to the chat if the value is not valid - - Returns: - dict: the parsed value, with a string at key "action". - - Raises: - ValueError: if the parser could not parse a valid value after n_retry retries. - BadRequestError: if the message is too long - """ - - for i in range(n_retry): - try: - answers = chat.generate([messages]).generations[0] # chat.n parallel completions - except BadRequestError as e: - # most likely, the added messages triggered a message too long error - # we thus retry without the last two messages - if i == 0: - raise e - msg = f"BadRequestError, most likely the message is too long retrying with previous query." - warn(msg) - messages = messages[:-2] - answers = chat.generate([messages]).generations[0] - - values, valids, retry_messages, scores = zip( - *[parser(answer.message.content) for answer in answers] - ) - idx = np.argmax(scores) - value = values[idx] - valid = valids[idx] - retry_message = retry_messages[idx] - answer = answers[idx].message - - if valid: - return value - - msg = f"Query failed. Retrying {i+1}/{n_retry}.\n[LLM]:\n{answer.content}\n[User]:\n{retry_message}" - warn(msg) - messages.append(answer) # already of type AIMessage - messages.append(SystemMessage(content=retry_message)) - - raise ValueError(f"Could not parse a valid value after {n_retry} retries.") - - def truncate_tokens(text, max_tokens=8000, start=0, model_name="gpt-4"): """Use tiktoken to truncate a text to a maximum number of tokens.""" enc = tiktoken.encoding_for_model(model_name) diff --git a/tests/llm/test_llm_utils.py b/tests/llm/test_llm_utils.py index 1bdbdacb..41dca395 100644 --- a/tests/llm/test_llm_utils.py +++ b/tests/llm/test_llm_utils.py @@ -94,40 +94,6 @@ def test_compress_string(): assert compressed_text == expected_output -@pytest.mark.pricy -def test_retry_parallel(): - chat = AzureChatOpenAI( - model_name="gpt-35-turbo", azure_deployment="gpt-35-turbo", temperature=0.2, n=3 - ) - prompt = """List primes from 1 to 10.""" - messages = [ - SystemMessage(content=prompt), - ] - - global n_call - n_call = 0 - - def parser(message): - global n_call - n_call += 1 - - if n_call <= 3: # First 3 calls, just answer the new prompt - return ( - None, - False, - "I changed my mind. List primes up to 15. 
Just answer the json list, nothing else.", - 0, - ) - elif n_call == 5: - return "success", True, "", 10 - else: - return "bad", True, "", 1 + np.random.rand() - - value = llm_utils.retry_parallel(chat, messages, parser=parser, n_retry=2) - assert value == "success" - assert n_call == 6 # 2*3 calls - - # Mock ChatOpenAI class class MockChatOpenAI: def invoke(self, messages): From 57b391323642eb3bbd8786336f77a249719254c5 Mon Sep 17 00:00:00 2001 From: Thibault Le Sellier de Chezelles Date: Fri, 20 Sep 2024 15:02:51 -0400 Subject: [PATCH 13/37] london 1666 (removing all (most) traces of langchain) --- src/agentlab/agents/dynamic_prompting.py | 4 +- .../agents/generic_agent/generic_agent.py | 6 +- .../most_basic_agent/most_basic_agent.py | 6 +- src/agentlab/agents/utils.py | 23 -------- src/agentlab/analyze/agent_xray.py | 26 ++++----- src/agentlab/analyze/inspect_results.py | 4 +- src/agentlab/llm/chat_api.py | 15 ++--- src/agentlab/llm/langchain_utils.py | 28 ++-------- src/agentlab/llm/llm_configs.py | 1 + src/agentlab/llm/llm_utils.py | 33 ++++++----- src/agentlab/llm/tracking.py | 5 +- tests/llm/test_chat_api.py | 55 +++++++++---------- tests/llm/test_langchain_utils.py | 9 ++- tests/llm/test_llm_utils.py | 19 +++---- tests/llm/test_tracking.py | 24 +++----- 15 files changed, 97 insertions(+), 161 deletions(-) delete mode 100644 src/agentlab/agents/utils.py diff --git a/src/agentlab/agents/dynamic_prompting.py b/src/agentlab/agents/dynamic_prompting.py index 91c3dd3f..a9073487 100644 --- a/src/agentlab/agents/dynamic_prompting.py +++ b/src/agentlab/agents/dynamic_prompting.py @@ -245,9 +245,7 @@ def fit_tokens( additional_prompts = [additional_prompts] for prompt in additional_prompts: - max_prompt_tokens -= ( - count_tokens(prompt, model=model_name) + 1 - ) # +1 accounts for LangChain token + max_prompt_tokens -= count_tokens(prompt, model=model_name) + 1 # +1 because why not ? 
for _ in range(max_iterations): prompt = shrinkable.prompt diff --git a/src/agentlab/agents/generic_agent/generic_agent.py b/src/agentlab/agents/generic_agent/generic_agent.py index eb3b0d87..82fb01b1 100644 --- a/src/agentlab/agents/generic_agent/generic_agent.py +++ b/src/agentlab/agents/generic_agent/generic_agent.py @@ -3,11 +3,9 @@ from warnings import warn from browsergym.experiments.agent import Agent -from langchain.schema import HumanMessage, SystemMessage from agentlab.agents import dynamic_prompting as dp from agentlab.agents.agent_args import AgentArgs -from agentlab.agents.utils import openai_monitored_agent from agentlab.llm.chat_api import BaseModelArgs from agentlab.llm.llm_utils import RetryError, retry_raise from agentlab.llm.tracking import get_action_decorator @@ -99,8 +97,8 @@ def get_action(self, obs): # cause it to be too long chat_messages = [ - SystemMessage(content=system_prompt), - HumanMessage(content=prompt), + {"role": "system", "content": system_prompt}, + {"role": "user", "content": prompt}, ] ans_dict = retry_raise( self.chat_llm, diff --git a/src/agentlab/agents/most_basic_agent/most_basic_agent.py b/src/agentlab/agents/most_basic_agent/most_basic_agent.py index 0c8d7023..1864c975 100644 --- a/src/agentlab/agents/most_basic_agent/most_basic_agent.py +++ b/src/agentlab/agents/most_basic_agent/most_basic_agent.py @@ -7,7 +7,6 @@ from browsergym.core.action.highlevel import HighLevelActionSet from browsergym.experiments.agent import Agent, AgentInfo from browsergym.experiments.loop import AbstractAgentArgs, EnvArgs, ExpArgs -from langchain.schema import AIMessage, HumanMessage, SystemMessage from agentlab.llm.llm_configs import CHAT_MODEL_ARGS_DICT from agentlab.llm.llm_utils import ParseError, extract_code_blocks, retry_raise @@ -82,7 +81,10 @@ def get_action(self, obs: Any) -> tuple[str, dict]: Provide a chain of thoughts reasoning to decompose the task into smaller steps. And execute only the next step. 
""" - messages = [SystemMessage(content=system_prompt), HumanMessage(content=prompt)] + messages = [ + {"role": "system", "content": system_prompt}, + {"role": "user", "content": prompt}, + ] def parser(response: str) -> tuple[dict, bool, str]: blocks = extract_code_blocks(response) diff --git a/src/agentlab/agents/utils.py b/src/agentlab/agents/utils.py deleted file mode 100644 index f4c6f19c..00000000 --- a/src/agentlab/agents/utils.py +++ /dev/null @@ -1,23 +0,0 @@ -from langchain_community.callbacks import get_openai_callback - - -def openai_monitored_agent(get_action_func): - def wrapper(self, obs): - with get_openai_callback() as openai_cb: - action, agent_info = get_action_func(self, obs) - - stats = { - "openai_total_cost": openai_cb.total_cost, - "openai_total_tokens": openai_cb.total_tokens, - "openai_completion_tokens": openai_cb.completion_tokens, - "openai_prompt_tokens": openai_cb.prompt_tokens, - } - - if "stats" in agent_info: - agent_info["stats"].update(stats) - else: - agent_info["stats"] = stats - - return action, agent_info - - return wrapper diff --git a/src/agentlab/analyze/agent_xray.py b/src/agentlab/analyze/agent_xray.py index caa462eb..f7793325 100644 --- a/src/agentlab/analyze/agent_xray.py +++ b/src/agentlab/analyze/agent_xray.py @@ -11,8 +11,7 @@ import pandas as pd from attr import dataclass from browsergym.experiments.loop import ExpResult, StepInfo -from langchain.schema import BaseMessage -from langchain_openai import ChatOpenAI +from openai import OpenAI from PIL import Image from agentlab.analyze import inspect_results @@ -560,9 +559,7 @@ def update_chat_messages(): chat_messages = agent_info.get("chat_messages", ["No Chat Messages"]) messages = [] for i, m in enumerate(chat_messages): - if isinstance(m, BaseMessage): - m = m.content - elif isinstance(m, dict): + if isinstance(m, dict): m = m.get("content", "No Content") messages.append(f"""# Message {i}\n```\n{m}\n```\n\n""") return "\n".join(messages) @@ -628,11 +625,16 @@ def submit_action(input_text): global info agent_info = info.exp_result.steps_info[info.step].agent_info chat_messages = deepcopy(agent_info.get("chat_messages", ["No Chat Messages"])[:2]) - assert isinstance(chat_messages[1], BaseMessage), "Messages should be langchain messages" - - chat = ChatOpenAI(name="gpt-4o-mini") - chat_messages[1].content = input_text - result_text = chat(chat_messages).content + assert isinstance(chat_messages[1], dict), "Messages should be a dict" + assert chat_messages[1].get("role", None) == "user", "Second message should be user" + + client = OpenAI() + chat_messages[1].get("content") = input_text + completion = client.chat.completions.create( + model="gpt-4o-mini", + messages=chat_messages, + ) + result_text = completion.choices[0].message.content return result_text @@ -641,9 +643,7 @@ def update_prompt_tests(): agent_info = info.exp_result.steps_info[info.step].agent_info chat_messages = agent_info.get("chat_messages", ["No Chat Messages"]) prompt = chat_messages[1] - if isinstance(prompt, BaseMessage): - prompt = prompt.content - elif isinstance(prompt, dict): + if isinstance(prompt, dict): prompt = prompt.get("content", "No Content") return prompt, prompt diff --git a/src/agentlab/analyze/inspect_results.py b/src/agentlab/analyze/inspect_results.py index d69b1656..c81752a7 100644 --- a/src/agentlab/analyze/inspect_results.py +++ b/src/agentlab/analyze/inspect_results.py @@ -291,9 +291,9 @@ def summarize_stats(sub_df): key_ = key.split(".")[1] op = key_.split("_")[0] if op == "cum": - record[key_] = 
sub_df[key].sum(skipna=True).round(3) + record[key_] = sub_df[key].sum(skipna=True).round(6) elif op == "max": - record[key_] = sub_df[key].max(skipna=True).round(3) + record[key_] = sub_df[key].max(skipna=True).round(6) else: raise ValueError(f"Unknown stats operation: {op}") return pd.Series(record) diff --git a/src/agentlab/llm/chat_api.py b/src/agentlab/llm/chat_api.py index 035838aa..eab3bd6b 100644 --- a/src/agentlab/llm/chat_api.py +++ b/src/agentlab/llm/chat_api.py @@ -4,25 +4,22 @@ from dataclasses import dataclass from typing import TYPE_CHECKING -from langchain.schema import AIMessage -from langchain_openai import AzureChatOpenAI, ChatOpenAI - from agentlab.llm.langchain_utils import ( ChatOpenRouter, HuggingFaceAPIChatModel, HuggingFaceURLChatModel, ) -from agentlab.llm.tracking import OpenAIChatModel, OpenRouterChatModel +from agentlab.llm.tracking import AzureChatModel, OpenAIChatModel, OpenRouterChatModel if TYPE_CHECKING: - from langchain_core.language_models.chat_models import BaseChatModel + from agentlab.llm.tracking import ChatModel class CheatMiniWoBLLM: """For unit-testing purposes only. It only work with miniwob.click-test task.""" def invoke(self, messages) -> str: - prompt = messages[-1].content + prompt = messages[-1].get("content", "") match = re.search(r"^\s*\[(\d+)\].*button", prompt, re.MULTILINE | re.IGNORECASE) if match: @@ -36,7 +33,7 @@ def invoke(self, messages) -> str: {action} """ - return AIMessage(content=answer) + return {"role": "assistant", "content": answer} def __call__(self, messages) -> str: return self.invoke(messages) @@ -71,7 +68,7 @@ class BaseModelArgs(ABC): vision_support: bool = False @abstractmethod - def make_model(self) -> "BaseChatModel": + def make_model(self) -> "ChatModel": pass def prepare_server(self): @@ -127,7 +124,7 @@ class AzureModelArgs(BaseModelArgs): deployment_name: str = None def make_model(self): - return AzureChatOpenAI( + return AzureChatModel( model_name=self.model_name, temperature=self.temperature, max_tokens=self.max_new_tokens, diff --git a/src/agentlab/llm/langchain_utils.py b/src/agentlab/llm/langchain_utils.py index 8b20b1df..b3383fcb 100644 --- a/src/agentlab/llm/langchain_utils.py +++ b/src/agentlab/llm/langchain_utils.py @@ -1,47 +1,27 @@ """This module helps uniformizing the interface of different LLMs for the AgentLab platform. 
We wish there was already a uniform interface.""" +import json import logging import os import time +from contextlib import contextmanager from functools import partial from typing import Any, List, Optional +import requests from huggingface_hub import InferenceClient from langchain.callbacks.manager import CallbackManagerForLLMRun from langchain.chat_models.base import SimpleChatModel from langchain.schema import AIMessage, BaseMessage, HumanMessage, SystemMessage -from langchain_community.chat_models import ChatOpenAI from langchain_community.llms import HuggingFaceHub, HuggingFacePipeline +from openai import OpenAI from pydantic import Field from transformers import AutoTokenizer, GPT2TokenizerFast, pipeline from agentlab.llm.prompt_templates import PromptTemplate, get_prompt_template -class ChatOpenRouter(ChatOpenAI): - openai_api_base: str - openai_api_key: str - model_name: str - - def __init__( - self, - model_name: str, - openai_api_key: Optional[str] = None, - openai_api_base: str = "https://openrouter.ai/api/v1", - temperature: Optional[float] = 0.5, - max_tokens: Optional[int] = 100, - ): - openai_api_key = openai_api_key or os.getenv("OPENROUTER_API_KEY") - super().__init__( - openai_api_base=openai_api_base, - openai_api_key=openai_api_key, - model_name=model_name, - temperature=temperature, - max_tokens=max_tokens, - ) - - class HFBaseChatModel(SimpleChatModel): """ Custom LLM Chatbot that can interface with HuggingFace models. diff --git a/src/agentlab/llm/llm_configs.py b/src/agentlab/llm/llm_configs.py index 9fbb0ee7..268735fb 100644 --- a/src/agentlab/llm/llm_configs.py +++ b/src/agentlab/llm/llm_configs.py @@ -25,6 +25,7 @@ max_total_tokens=128_000, max_input_tokens=40_000, max_new_tokens=4000, + vision_support=True, ), "openai/gpt-4-1106-preview": OpenAIModelArgs( model_name="gpt-4-1106-preview", diff --git a/src/agentlab/llm/llm_utils.py b/src/agentlab/llm/llm_utils.py index 5c16ef0d..403adf3b 100644 --- a/src/agentlab/llm/llm_utils.py +++ b/src/agentlab/llm/llm_utils.py @@ -13,13 +13,12 @@ import numpy as np import tiktoken import yaml -from langchain.schema import BaseMessage, HumanMessage, SystemMessage from openai import BadRequestError, RateLimitError from PIL import Image from transformers import AutoModel, AutoTokenizer if TYPE_CHECKING: - from langchain_core.language_models import BaseChatModel + from agentlab.llm.tracking import ChatModel def _extract_wait_time(error_message, min_retry_wait_time=60): @@ -35,8 +34,8 @@ class RetryError(ValueError): def retry( - chat: "BaseChatModel", - messages, + chat: "ChatModel", + messages: list[dict], n_retry, parser, log=True, @@ -53,9 +52,9 @@ def retry( and expensive. Args: - chat (BaseChatModel): a langchain BaseChatModel taking a list of messages and - returning a list of answers. - messages (list): the list of messages so far. + chat (ChatModel): a ChatModel object taking a list of messages and + returning a list of answers, all in OpenAI format. + messages (list): the list of messages so far, in OpenAI format. n_retry (int): the maximum number of sequential retries. parser (function): a function taking a message and returning a tuple with the following fields: @@ -94,22 +93,22 @@ def retry( messages.append(answer) - value, valid, retry_message = parser(answer.content) + value, valid, retry_message = parser(answer.get("content")) if valid: return value tries += 1 if log: - msg = f"Query failed. Retrying {tries}/{n_retry}.\n[LLM]:\n{answer.content}\n[User]:\n{retry_message}" + msg = f"Query failed. 
Retrying {tries}/{n_retry}.\n[LLM]:\n{answer.get("content")}\n[User]:\n{retry_message}" logging.info(msg) - messages.append(HumanMessage(content=retry_message)) + messages.append(dict(role="user", content=retry_message)) raise RetryError(f"Could not parse a valid value after {n_retry} retries.") def retry_raise( - chat: "BaseChatModel", - messages: list[BaseMessage], + chat: "ChatModel", + messages: list[dict], n_retry: int, parser: callable, log: bool = True, @@ -126,8 +125,8 @@ def retry_raise( and expensive. Args: - chat (BaseChatModel): a langchain BaseChatModel taking a list of messages and - returning a list of answers. + chat (ChatModel): a ChatModel object taking a list of messages and + returning a list of answers, all in OpenAI format. messages (list): the list of messages so far. This list will be modified with the new messages and the retry messages. n_retry (int): the maximum number of sequential retries. @@ -166,13 +165,13 @@ def retry_raise( messages.append(answer) # TODO: could we change this to not use inplace modifications ? try: - return parser(answer.content) + return parser(answer.get("content")) except ParseError as parsing_error: tries += 1 if log: - msg = f"Query failed. Retrying {tries}/{n_retry}.\n[LLM]:\n{answer.content}\n[User]:\n{str(parsing_error)}" + msg = f"Query failed. Retrying {tries}/{n_retry}.\n[LLM]:\n{answer.get("content")}\n[User]:\n{str(parsing_error)}" logging.info(msg) - messages.append(HumanMessage(content=str(parsing_error))) + messages.append(dict(role="user", content=str(parsing_error))) raise RetryError(f"Could not parse a valid value after {n_retry} retries.") diff --git a/src/agentlab/llm/tracking.py b/src/agentlab/llm/tracking.py index 8a530e73..7531dd8d 100644 --- a/src/agentlab/llm/tracking.py +++ b/src/agentlab/llm/tracking.py @@ -5,12 +5,9 @@ from contextlib import contextmanager import requests -from langchain.schema import AIMessage, BaseMessage from langchain_community.callbacks.openai_info import MODEL_COST_PER_1K_TOKENS from openai import AzureOpenAI, OpenAI -from agentlab.llm.langchain_utils import _convert_messages_to_dict - TRACKER = threading.local() @@ -126,7 +123,7 @@ def __call__(self, messages: list[dict]) -> dict: output_tokens = completion.usage.completion_tokens cost = input_tokens * self.input_cost + output_tokens * self.output_cost - if isinstance(TRACKER.instance, LLMTracker): + if hasattr(TRACKER, "instance") and isinstance(TRACKER.instance, LLMTracker): TRACKER.instance(input_tokens, output_tokens, cost) return dict(role="assistant", content=completion.choices[0].message.content) diff --git a/tests/llm/test_chat_api.py b/tests/llm/test_chat_api.py index cf62fe32..aac05e2c 100644 --- a/tests/llm/test_chat_api.py +++ b/tests/llm/test_chat_api.py @@ -1,11 +1,8 @@ import os import pytest -from langchain.schema import HumanMessage, SystemMessage from agentlab.llm.chat_api import AzureModelArgs, HuggingFaceModelArgs, OpenAIModelArgs -from agentlab.llm.llm_utils import download_and_save_model -from agentlab.llm.prompt_templates import STARCHAT_PROMPT_TEMPLATE # TODO(optimass): figure out a good model for all tests @@ -16,27 +13,27 @@ skip_tests = False -@pytest.mark.pricy -@pytest.mark.skipif(skip_tests, reason="Skipping on remote as HF token have limited usage") -def test_api_model_args_hf(): - model_name = "HuggingFaceH4/starchat-beta" +# @pytest.mark.pricy +# @pytest.mark.skipif(skip_tests, reason="Skipping on remote as HF token have limited usage") +# def test_api_model_args_hf(): +# model_name = 
"HuggingFaceH4/starchat-beta" - model_args = HuggingFaceModelArgs( - model_name=model_name, - max_total_tokens=8192, - max_input_tokens=8192 - 512, - max_new_tokens=512, - temperature=1e-1, - ) - model = model_args.make_model() +# model_args = HuggingFaceModelArgs( +# model_name=model_name, +# max_total_tokens=8192, +# max_input_tokens=8192 - 512, +# max_new_tokens=512, +# temperature=1e-1, +# ) +# model = model_args.make_model() - messages = [ - SystemMessage(content="You are an helpful virtual assistant"), - HumanMessage(content="Give the third prime number"), - ] - answer = model.invoke(messages) +# messages = [ +# SystemMessage(content="You are an helpful virtual assistant"), +# HumanMessage(content="Give the third prime number"), +# ] +# answer = model.invoke(messages) - assert "5" in answer.content +# assert "5" in answer.content @pytest.mark.pricy @@ -53,19 +50,19 @@ def test_api_model_args_azure(): model = model_args.make_model() messages = [ - SystemMessage(content="You are an helpful virtual assistant"), - HumanMessage(content="Give the third prime number"), + dict(role="system", content="You are an helpful virtual assistant"), + dict(role="user", content="Give the third prime number"), ] answer = model.invoke(messages) - assert "5" in answer.content + assert "5" in answer.get("content") @pytest.mark.pricy -@pytest.mark.skip(reason="Skipping atm for lack of better marking") +@pytest.mark.skipif(skip_tests, reason="Skipping on remote as Azure is pricy") def test_api_model_args_openai(): model_args = OpenAIModelArgs( - model_name="gpt-3.5-turbo-0125", + model_name="gpt-4o-mini", max_total_tokens=8192, max_input_tokens=8192 - 512, max_new_tokens=512, @@ -74,9 +71,9 @@ def test_api_model_args_openai(): model = model_args.make_model() messages = [ - SystemMessage(content="You are an helpful virtual assistant"), - HumanMessage(content="Give the third prime number"), + dict(role="system", content="You are an helpful virtual assistant"), + dict(role="user", content="Give the third prime number"), ] answer = model.invoke(messages) - assert "5" in answer.content + assert "5" in answer.get("content") diff --git a/tests/llm/test_langchain_utils.py b/tests/llm/test_langchain_utils.py index e4de8a3d..5e2b7c6f 100644 --- a/tests/llm/test_langchain_utils.py +++ b/tests/llm/test_langchain_utils.py @@ -1,5 +1,4 @@ import pytest -from langchain.schema import HumanMessage, SystemMessage from agentlab.llm.chat_api import HuggingFaceAPIChatModel, HuggingFaceURLChatModel from agentlab.llm.llm_utils import download_and_save_model @@ -20,8 +19,8 @@ def test_CustomLLMChatbot_remotely(): ) messages = [ - SystemMessage(content="You are an helpful virtual assistant"), - HumanMessage(content="Is python a programming language?"), + dict(role="system", content="You are an helpful virtual assistant"), + dict(role="user", content="Is python a programming language?"), ] answer = chatbot(messages) @@ -37,8 +36,8 @@ def test_CustomLLMChatbot_locally(): chatbot = HuggingFaceURLChatModel(model_path=model_path, temperature=1e-3) messages = [ - SystemMessage(content="Please tell me back the following word: "), - HumanMessage(content="bird"), + dict(role="system", content="Please tell me back the following word: "), + dict(role="user", content="bird"), ] answer = chatbot(messages) diff --git a/tests/llm/test_llm_utils.py b/tests/llm/test_llm_utils.py index 41dca395..a9edd7a8 100644 --- a/tests/llm/test_llm_utils.py +++ b/tests/llm/test_llm_utils.py @@ -1,13 +1,10 @@ import warnings from typing import Literal -from 
unittest import mock from unittest.mock import Mock import httpx import numpy as np import pytest -from langchain.schema import SystemMessage -from langchain_openai import AzureChatOpenAI from openai import RateLimitError from agentlab.llm import llm_utils @@ -152,7 +149,7 @@ def test_rate_limit_success(): mock_chat.invoke = Mock( side_effect=[ mock_rate_limit_error("Rate limit reached. Please try again in 2s."), - SystemMessage(content="correct content"), + dict(role="system", content="correct content"), ] ) @@ -177,9 +174,9 @@ def test_successful_parse_before_max_retries(): # content on the 3rd time mock_chat.invoke = Mock( side_effect=[ - SystemMessage(content="wrong content"), - SystemMessage(content="wrong content"), - SystemMessage(content="correct content"), + dict(role="system", content="wrong content"), + dict(role="system", content="wrong content"), + dict(role="system", content="correct content"), ] ) @@ -196,9 +193,9 @@ def test_unsuccessful_parse_before_max_retries(): # content on the 3rd time mock_chat.invoke = Mock( side_effect=[ - SystemMessage(content="wrong content"), - SystemMessage(content="wrong content"), - SystemMessage(content="correct content"), + dict(role="system", content="wrong content"), + dict(role="system", content="wrong content"), + dict(role="system", content="correct content"), ] ) with pytest.raises(ValueError): @@ -209,7 +206,7 @@ def test_unsuccessful_parse_before_max_retries(): def test_retry_parse_raises(): mock_chat = MockChatOpenAI() - mock_chat.invoke = Mock(return_value=SystemMessage(content="mocked response")) + mock_chat.invoke = Mock(return_value=dict(role="system", content="mocked response")) parser_raises = Mock(side_effect=ValueError("Parser error")) with pytest.raises(ValueError): diff --git a/tests/llm/test_tracking.py b/tests/llm/test_tracking.py index 4672cacd..9e5a13ac 100644 --- a/tests/llm/test_tracking.py +++ b/tests/llm/test_tracking.py @@ -122,15 +122,13 @@ def test_openai_chat_model(): assert chat_model.input_cost > 0 assert chat_model.output_cost > 0 - from langchain.schema import HumanMessage, SystemMessage - messages = [ - SystemMessage(content="You are an helpful virtual assistant"), - HumanMessage(content="Give the third prime number"), + dict(role="system", content="You are an helpful virtual assistant"), + dict(role="user", content="Give the third prime number"), ] with tracking.set_tracker() as tracker: answer = chat_model.invoke(messages) - assert "5" in answer.content + assert "5" in answer.get("content") assert tracker.stats["cost"] > 0 @@ -149,15 +147,13 @@ def test_azure_chat_model(): assert chat_model.input_cost > 0 assert chat_model.output_cost > 0 - from langchain.schema import HumanMessage, SystemMessage - messages = [ - SystemMessage(content="You are an helpful virtual assistant"), - HumanMessage(content="Give the third prime number"), + dict(role="system", content="You are an helpful virtual assistant"), + dict(role="user", content="Give the third prime number"), ] with tracking.set_tracker() as tracker: answer = chat_model.invoke(messages) - assert "5" in answer.content + assert "5" in answer.get("content") assert tracker.stats["cost"] > 0 @@ -168,13 +164,11 @@ def test_openrouter_chat_model(): assert chat_model.input_cost > 0 assert chat_model.output_cost > 0 - from langchain.schema import HumanMessage, SystemMessage - messages = [ - SystemMessage(content="You are an helpful virtual assistant"), - HumanMessage(content="Give the third prime number"), + dict(role="system", content="You are an helpful virtual 
assistant"), + dict(role="user", content="Give the third prime number"), ] with tracking.set_tracker() as tracker: answer = chat_model.invoke(messages) - assert "5" in answer.content + assert "5" in answer.get("content") assert tracker.stats["cost"] > 0 From 592c6762893446acbd94e40778cefa9417b5df02 Mon Sep 17 00:00:00 2001 From: ThibaultLSDC Date: Tue, 24 Sep 2024 14:24:46 -0400 Subject: [PATCH 14/37] renaming langchain_utils to huggingface_utils --- ...angchain_utils.py => huggingface_utils.py} | 200 +++++++----------- 1 file changed, 75 insertions(+), 125 deletions(-) rename src/agentlab/llm/{langchain_utils.py => huggingface_utils.py} (50%) diff --git a/src/agentlab/llm/langchain_utils.py b/src/agentlab/llm/huggingface_utils.py similarity index 50% rename from src/agentlab/llm/langchain_utils.py rename to src/agentlab/llm/huggingface_utils.py index b3383fcb..9672c738 100644 --- a/src/agentlab/llm/langchain_utils.py +++ b/src/agentlab/llm/huggingface_utils.py @@ -1,28 +1,17 @@ -"""This module helps uniformizing the interface of different LLMs for the -AgentLab platform. We wish there was already a uniform interface.""" - -import json import logging import os import time -from contextlib import contextmanager from functools import partial from typing import Any, List, Optional -import requests from huggingface_hub import InferenceClient -from langchain.callbacks.manager import CallbackManagerForLLMRun -from langchain.chat_models.base import SimpleChatModel -from langchain.schema import AIMessage, BaseMessage, HumanMessage, SystemMessage -from langchain_community.llms import HuggingFaceHub, HuggingFacePipeline -from openai import OpenAI from pydantic import Field -from transformers import AutoTokenizer, GPT2TokenizerFast, pipeline +from transformers import AutoTokenizer, GPT2TokenizerFast from agentlab.llm.prompt_templates import PromptTemplate, get_prompt_template -class HFBaseChatModel(SimpleChatModel): +class HFBaseChatModel: """ Custom LLM Chatbot that can interface with HuggingFace models. @@ -62,28 +51,25 @@ def __init__(self, model_name, n_retry_server): self.tokenizer = None self.prompt_template = get_prompt_template(model_name) - def _call( + def __call__( self, - messages: List[BaseMessage], - stop: Optional[List[str]] = None, - run_manager: Optional[CallbackManagerForLLMRun] = None, - **kwargs: Any, - ) -> str: + messages: list[dict], + ) -> dict: # NOTE: The `stop`, `run_manager`, and `kwargs` arguments are ignored in this implementation. if self.tokenizer: - messages_formated = _convert_messages_to_dict(messages) + # messages_formated = _convert_messages_to_dict(messages) ## ? try: - prompt = self.tokenizer.apply_chat_template(messages_formated, tokenize=False) + prompt = self.tokenizer.apply_chat_template(messages, tokenize=False) except Exception as e: if "Conversation roles must alternate" in str(e): logging.warning( f"Failed to apply the chat template. Maybe because it doesn't support the 'system' role" "Retrying with the 'system' role appended to the 'user' role." 
) - messages_formated = _prepend_system_to_first_user(messages_formated) - prompt = self.tokenizer.apply_chat_template(messages_formated, tokenize=False) + messages = _prepend_system_to_first_user(messages) + prompt = self.tokenizer.apply_chat_template(messages, tokenize=False) else: raise e @@ -108,44 +94,8 @@ def _call( def _llm_type(self): return "huggingface" - -class HuggingFaceAPIChatModel(HFBaseChatModel): - def __init__( - self, - model_name: str, - temperature: Optional[int] = 1e-1, - max_new_tokens: Optional[int] = 512, - n_retry_server: Optional[int] = 4, - ): - super().__init__(model_name, n_retry_server) - if temperature < 1e-3: - logging.warning("Models might behave weirdly when temperature is too low.") - self.llm = HuggingFaceHub( - repo_id=model_name, - model_kwargs={"temperature": temperature, "max_length": max_new_tokens}, - ) - - -class HuggingFaceLocalChatModel(HFBaseChatModel): - def __init__( - self, - model_name: str, - temperature: Optional[int] = 1e-1, - max_new_tokens: Optional[int] = 512, - n_retry_server: Optional[int] = 4, - ): - super().__init__(model_name, n_retry_server) - if temperature < 1e-3: - logging.warning("Models might behave weirdly when temperature is too low.") - self.llm = HuggingFacePipeline( - pipeline( - task="text-generation", - model=model_name, - device_map="auto", - max_new_tokens=max_new_tokens, - model_kwargs={"temperature": temperature}, - ) - ) + def invoke(self, messages: list[dict]) -> dict: + return self(messages) class HuggingFaceURLChatModel(HFBaseChatModel): @@ -171,70 +121,70 @@ def __init__( ) -def _convert_messages_to_dict(messages, column_remap={}): - """ - Converts a list of message objects into a list of dictionaries, categorizing each message by its role. - - Each message is expected to be an instance of one of the following types: SystemMessage, HumanMessage, AIMessage. - The function maps each message to its corresponding role ('system', 'user', 'assistant') and formats it into a dictionary. - - Args: - messages (list): A list of message objects. - column_remap (dict): A dictionary that maps the column names to the desired output format. - - Returns: - list: A list of dictionaries where each dictionary represents a message and contains 'role' and 'content' keys. - - Raises: - ValueError: If an unsupported message type is encountered. 
- - Example: - >>> messages = [SystemMessage("System initializing..."), HumanMessage("Hello!"), AIMessage("How can I assist?")] - >>> _convert_messages_to_dict(messages) - [ - {"role": "system", "content": "System initializing..."}, - {"role": "user", "content": "Hello!"}, - {"role": "assistant", "content": "How can I assist?"} - ] - """ - - human_key = column_remap.get("HumanMessage", "user") - ai_message_key = column_remap.get("AIMessage", "assistant") - role_key = column_remap.get("role", "role") - text_key = column_remap.get("text", "content") - image_key = column_remap.get("image", "media_url") - - # Mapping of message types to roles - message_type_to_role = { - SystemMessage: "system", - HumanMessage: human_key, - AIMessage: ai_message_key, - } - - def convert_format_vision(message_content, role, text_key, image_key): - result = {} - result["type"] = role - for item in message_content: - if item["type"] == "text": - result[text_key] = item["text"] - elif item["type"] == "image_url": - result[image_key] = item["image_url"] - return result - - chat = [] - for message in messages: - message_role = message_type_to_role.get(type(message)) - if message_role: - if isinstance(message.content, str): - chat.append({role_key: message_role, text_key: message.content}) - else: - chat.append( - convert_format_vision(message.content, message_role, text_key, image_key) - ) - else: - raise ValueError(f"Message type {type(message)} not supported") - - return chat +# def _convert_messages_to_dict(messages, column_remap={}): +# """ +# Converts a list of message objects into a list of dictionaries, categorizing each message by its role. + +# Each message is expected to be an instance of one of the following types: SystemMessage, HumanMessage, AIMessage. +# The function maps each message to its corresponding role ('system', 'user', 'assistant') and formats it into a dictionary. + +# Args: +# messages (list): A list of message objects. +# column_remap (dict): A dictionary that maps the column names to the desired output format. + +# Returns: +# list: A list of dictionaries where each dictionary represents a message and contains 'role' and 'content' keys. + +# Raises: +# ValueError: If an unsupported message type is encountered. 
+ +# Example: +# >>> messages = [SystemMessage("System initializing..."), HumanMessage("Hello!"), AIMessage("How can I assist?")] +# >>> _convert_messages_to_dict(messages) +# [ +# {"role": "system", "content": "System initializing..."}, +# {"role": "user", "content": "Hello!"}, +# {"role": "assistant", "content": "How can I assist?"} +# ] +# """ + +# human_key = column_remap.get("HumanMessage", "user") +# ai_message_key = column_remap.get("AIMessage", "assistant") +# role_key = column_remap.get("role", "role") +# text_key = column_remap.get("text", "content") +# image_key = column_remap.get("image", "media_url") + +# # Mapping of message types to roles +# message_type_to_role = { +# SystemMessage: "system", +# HumanMessage: human_key, +# AIMessage: ai_message_key, +# } + +# def convert_format_vision(message_content, role, text_key, image_key): +# result = {} +# result["type"] = role +# for item in message_content: +# if item["type"] == "text": +# result[text_key] = item["text"] +# elif item["type"] == "image_url": +# result[image_key] = item["image_url"] +# return result + +# chat = [] +# for message in messages: +# message_role = message_type_to_role.get(type(message)) +# if message_role: +# if isinstance(message.content, str): +# chat.append({role_key: message_role, text_key: message.content}) +# else: +# chat.append( +# convert_format_vision(message.content, message_role, text_key, image_key) +# ) +# else: +# raise ValueError(f"Message type {type(message)} not supported") + +# return chat def _prepend_system_to_first_user(messages, column_remap={}): From bdab2fb77392611173b3d0a0f33caa261cf0b923 Mon Sep 17 00:00:00 2001 From: ThibaultLSDC Date: Tue, 24 Sep 2024 14:25:08 -0400 Subject: [PATCH 15/37] deps update --- requirements.txt | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/requirements.txt b/requirements.txt index 9dce3fcc..9349804e 100644 --- a/requirements.txt +++ b/requirements.txt @@ -9,8 +9,6 @@ distributed browsergym>=0.7.0 joblib>=1.2.0 openai>=1.7,<2 -langchain>=0.1,<1 -langchain_openai langchain_community tiktoken huggingface_hub @@ -20,3 +18,4 @@ pyyaml>=6 pandas gradio gitpython # for the reproducibility script +requests \ No newline at end of file From 9708961d4e293d82504e1858cb32a982d2fd059b Mon Sep 17 00:00:00 2001 From: ThibaultLSDC Date: Tue, 24 Sep 2024 14:26:27 -0400 Subject: [PATCH 16/37] removing last langchain traces --- src/agentlab/llm/README.md | 6 ----- src/agentlab/llm/chat_api.py | 38 +++++++--------------------- src/agentlab/llm/prompt_templates.py | 25 ++++++++---------- src/agentlab/llm/tracking.py | 3 --- tests/llm/test_llm_utils.py | 1 - 5 files changed, 20 insertions(+), 53 deletions(-) diff --git a/src/agentlab/llm/README.md b/src/agentlab/llm/README.md index 81c94550..dea60ddf 100644 --- a/src/agentlab/llm/README.md +++ b/src/agentlab/llm/README.md @@ -95,12 +95,6 @@ TODO - in their demo, they queried the SNOW UI! 
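For reference, the `_prepend_system_to_first_user` helper kept in patch 14 above is what the chat-template fallback relies on when a tokenizer rejects the 'system' role (despite the log message saying "appended", the helper's name suggests the system text is prepended). A minimal sketch of that behavior, ignoring the `column_remap` argument and with the newline separator as an assumption:

    def _prepend_system_to_first_user(messages: list[dict]) -> list[dict]:
        # Fold system content into the first user turn, for chat templates
        # that only accept alternating user/assistant roles.
        system_content = "\n".join(m["content"] for m in messages if m["role"] == "system")
        rest = [dict(m) for m in messages if m["role"] != "system"]
        if system_content and rest and rest[0]["role"] == "user":
            rest[0]["content"] = system_content + "\n" + rest[0]["content"]
        return rest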
- -## Relevant agentic tools - -- [Langchain Agents](https://python.langchain.com/docs/modules/agents/) - - ## Relevant Benchmarks - [bigcode/bigcode-models-leaderboard](https://huggingface.co/spaces/bigcode/bigcode-models-leaderboard) diff --git a/src/agentlab/llm/chat_api.py b/src/agentlab/llm/chat_api.py index 112e611d..77b6866d 100644 --- a/src/agentlab/llm/chat_api.py +++ b/src/agentlab/llm/chat_api.py @@ -4,16 +4,10 @@ from dataclasses import dataclass from typing import TYPE_CHECKING -from agentlab.llm.langchain_utils import ( - ChatOpenRouter, - HuggingFaceAPIChatModel, - HuggingFaceURLChatModel, - _convert_messages_to_dict, -) -from agentlab.llm.tracking import AzureChatModel, OpenAIChatModel, OpenRouterChatModel +from openai import AzureOpenAI, OpenAI -if TYPE_CHECKING: - from agentlab.llm.tracking import ChatModel +import agentlab.llm.tracking as tracking +from agentlab.llm.huggingface_utils import HuggingFaceURLChatModel class CheatMiniWoBLLM: @@ -105,19 +99,6 @@ def make_model(self): ) -@dataclass -class HuggingFaceModelArgs(BaseModelArgs): - """Serializable object for instantiating a generic chat model with a HuggingFace model.""" - - def make_model(self): - return HuggingFaceAPIChatModel( - model_name=self.model_name, - temperature=self.temperature, - max_new_tokens=self.max_new_tokens, - n_retry_server=4, - ) - - @dataclass class AzureModelArgs(BaseModelArgs): """Serializable object for instantiating a generic chat model with an Azure model.""" @@ -203,16 +184,15 @@ def __init__(self, model_name, api_key=None, temperature=0.5, max_tokens=100): self.temperature = temperature self.max_tokens = max_tokens - self.client = tracking.OpenAI() + self.client = OpenAI() self.input_cost = 0.0 self.output_cost = 0.0 def __call__(self, messages: list[dict]) -> dict: - messages_formatted = _convert_messages_to_dict(messages) completion = self.client.chat.completions.create( model=self.model_name, - messages=messages_formatted, + messages=messages, temperature=self.temperature, max_tokens=self.max_tokens, ) @@ -225,7 +205,7 @@ def __call__(self, messages: list[dict]) -> dict: ): tracking.TRACKER.instance(input_tokens, output_tokens, cost) - return AIMessage(content=completion.choices[0].message.content) + return dict(role="assistant", content=completion.choices[0].message.content) def invoke(self, messages: list[dict]) -> dict: return self(messages) @@ -250,7 +230,7 @@ def __init__( self.input_cost = pricings[model_name]["prompt"] self.output_cost = pricings[model_name]["completion"] - self.client = tracking.OpenAI( + self.client = OpenAI( base_url="https://openrouter.ai/api/v1", api_key=api_key, ) @@ -275,7 +255,7 @@ def __init__( self.input_cost = float(pricings[model_name]["prompt"]) self.output_cost = float(pricings[model_name]["completion"]) - self.client = tracking.OpenAI( + self.client = OpenAI( api_key=api_key, ) @@ -304,7 +284,7 @@ def __init__( self.input_cost = float(pricings[model_name]["prompt"]) self.output_cost = float(pricings[model_name]["completion"]) - self.client = tracking.AzureOpenAI( + self.client = AzureOpenAI( api_key=api_key, azure_deployment=deployment_name, azure_endpoint=endpoint, diff --git a/src/agentlab/llm/prompt_templates.py b/src/agentlab/llm/prompt_templates.py index 821575fa..afcacdca 100644 --- a/src/agentlab/llm/prompt_templates.py +++ b/src/agentlab/llm/prompt_templates.py @@ -1,9 +1,6 @@ -import logging from dataclasses import dataclass from typing import List -from langchain.schema import AIMessage, BaseMessage, HumanMessage, SystemMessage - """ 
To use this class, you should have the ``openai`` python package installed, and the environment variable ``OPENAI_API_KEY`` set with your API key. @@ -24,7 +21,7 @@ class PromptTemplate: ai: str prompt_end: str = "" - def format_message(self, message: BaseMessage) -> str: + def format_message(self, message: dict) -> str: """ Formats a given message based on its type. @@ -37,16 +34,16 @@ def format_message(self, message: BaseMessage) -> str: Raises: ValueError: If the message type is not supported. """ - if isinstance(message, SystemMessage): - return self.system.format(input=message.content) - elif isinstance(message, HumanMessage): - return self.human.format(input=message.content) - elif isinstance(message, AIMessage): - return self.ai.format(input=message.content) + if message["role"] == "system": + return self.system.format(input=message["content"]) + elif message["role"] == "user": + return self.human.format(input=message["content"]) + elif message["role"] == "assistant": + return self.ai.format(input=message["content"]) else: - raise ValueError(f"Message type {type(message)} not supported") + raise ValueError(f"Message role {message['role']} not supported") - def construct_prompt(self, messages: List[BaseMessage]) -> str: + def construct_prompt(self, messages: List[dict]) -> str: """ Constructs a prompt from a list of messages. @@ -59,8 +56,8 @@ def construct_prompt(self, messages: List[BaseMessage]) -> str: Raises: ValueError: If any element in the list is not of type BaseMessage. """ - if not all(isinstance(m, BaseMessage) for m in messages): - raise ValueError("All elements in the list must be of type BaseMessage") + if not all(isinstance(m, dict) and "role" in m and "content" in m for m in messages): + raise ValueError("All elements in the list must be in openai format") prompt = "".join([self.format_message(m) for m in messages]) prompt += self.prompt_end diff --git a/src/agentlab/llm/tracking.py b/src/agentlab/llm/tracking.py index b44944e1..7e2761ac 100644 --- a/src/agentlab/llm/tracking.py +++ b/src/agentlab/llm/tracking.py @@ -1,12 +1,9 @@ -import ast import os import threading -from abc import ABC, abstractmethod from contextlib import contextmanager import requests from langchain_community.callbacks.openai_info import MODEL_COST_PER_1K_TOKENS -from openai import AzureOpenAI, OpenAI TRACKER = threading.local() diff --git a/tests/llm/test_llm_utils.py b/tests/llm/test_llm_utils.py index a9edd7a8..c8c59400 100644 --- a/tests/llm/test_llm_utils.py +++ b/tests/llm/test_llm_utils.py @@ -3,7 +3,6 @@ from unittest.mock import Mock import httpx -import numpy as np import pytest from openai import RateLimitError From 631057de47c5626c55745bbc6ab454bd3804d1c9 Mon Sep 17 00:00:00 2001 From: ThibaultLSDC Date: Tue, 24 Sep 2024 15:52:08 -0400 Subject: [PATCH 17/37] import typo --- src/agentlab/llm/llm_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/agentlab/llm/llm_utils.py b/src/agentlab/llm/llm_utils.py index 403adf3b..0a6b1525 100644 --- a/src/agentlab/llm/llm_utils.py +++ b/src/agentlab/llm/llm_utils.py @@ -18,7 +18,7 @@ from transformers import AutoModel, AutoTokenizer if TYPE_CHECKING: - from agentlab.llm.tracking import ChatModel + from agentlab.llm.chat_api import ChatModel def _extract_wait_time(error_message, min_retry_wait_time=60): From 560f3e5fc8ddd2c600e84dc0ed06ce2b62cc94c6 Mon Sep 17 00:00:00 2001 From: ThibaultLSDC Date: Tue, 24 Sep 2024 15:57:08 -0400 Subject: [PATCH 18/37] adding retry functionality to ChatModel --- 
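Note on the pattern this patch introduces: instead of retrying only at the llm_utils level, each ChatModel.__call__ now loops over the completion call itself, parsing a server-suggested delay out of the error message before sleeping. A standalone sketch of that loop; the `_extract_wait_time` signature matches the llm_utils helper imported in the diff below, but its regex here is an assumption, as is the `complete_with_retry` wrapper name:

    import logging
    import re
    import time

    import openai
    from openai import OpenAI

    def _extract_wait_time(error_message: str, min_retry_wait_time: float = 60) -> float:
        # Parse a hint like "Please try again in 2s."; never wait less than the floor.
        match = re.search(r"try again in (\d+(?:\.\d+)?)s", error_message)
        return max(float(match.group(1)), min_retry_wait_time) if match else min_retry_wait_time

    def complete_with_retry(client: OpenAI, model: str, messages: list[dict], max_retry: int = 4):
        last_error = None
        for attempt in range(max_retry):
            try:
                return client.chat.completions.create(model=model, messages=messages)
            except openai.OpenAIError as e:  # rate limits, timeouts, server errors
                last_error = e
                wait_time = _extract_wait_time(str(e))
                logging.warning(f"API error, retrying ({attempt + 1}/{max_retry}) in {wait_time}s")
                time.sleep(wait_time)
        raise Exception(f"Failed to get a response from the API after {max_retry} retries") from last_error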
src/agentlab/llm/chat_api.py | 56 ++++++++++++++++++++++-------------- 1 file changed, 35 insertions(+), 21 deletions(-) diff --git a/src/agentlab/llm/chat_api.py b/src/agentlab/llm/chat_api.py index 77b6866d..02482ca4 100644 --- a/src/agentlab/llm/chat_api.py +++ b/src/agentlab/llm/chat_api.py @@ -1,13 +1,17 @@ +import logging import os import re +import time from abc import ABC, abstractmethod from dataclasses import dataclass from typing import TYPE_CHECKING +import openai from openai import AzureOpenAI, OpenAI import agentlab.llm.tracking as tracking from agentlab.llm.huggingface_utils import HuggingFaceURLChatModel +from agentlab.llm.llm_utils import _extract_wait_time class CheatMiniWoBLLM: @@ -179,10 +183,11 @@ def make_model(self): class ChatModel(ABC): @abstractmethod - def __init__(self, model_name, api_key=None, temperature=0.5, max_tokens=100): + def __init__(self, model_name, api_key=None, temperature=0.5, max_tokens=100, max_retry=1): self.model_name = model_name self.temperature = temperature self.max_tokens = max_tokens + self.max_retry = max_retry self.client = OpenAI() @@ -190,12 +195,29 @@ def __init__(self, model_name, api_key=None, temperature=0.5, max_tokens=100): self.output_cost = 0.0 def __call__(self, messages: list[dict]) -> dict: - completion = self.client.chat.completions.create( - model=self.model_name, - messages=messages, - temperature=self.temperature, - max_tokens=self.max_tokens, - ) + completion = None + for itr in range(self.max_retry): + try: + completion = self.client.chat.completions.create( + model=self.model_name, + messages=messages, + temperature=self.temperature, + max_tokens=self.max_tokens, + ) + break + except openai.OpenAIError as e: + logging.warning( + f"Failed to get a response from the API: \n{e}\n" + f"Retrying... ({itr+1}/{self.max_retry})" + ) + wait_time = _extract_wait_time(e) + logging.info(f"Waiting for {wait_time} seconds") + time.sleep(wait_time) + # TODO: add total delay limit ? 
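            # (a sketch for the TODO above, borrowing the budget logic that the
            #  old llm_utils.retry used with rate_limit_max_wait_time; the two
            #  names below would be a new local and a new argument here)
            #     total_delay += wait_time
            #     if total_delay >= rate_limit_max_wait_time:
            #         raise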
+ + if not completion: + raise Exception("Failed to get a response from the API") + input_tokens = completion.usage.prompt_tokens output_tokens = completion.usage.completion_tokens cost = input_tokens * self.input_cost + output_tokens * self.output_cost @@ -212,16 +234,11 @@ def invoke(self, messages: list[dict]) -> dict: class OpenRouterChatModel(ChatModel): - def __init__( - self, - model_name, - api_key=None, - temperature=0.5, - max_tokens=100, - ): + def __init__(self, model_name, api_key=None, temperature=0.5, max_tokens=100, max_retry=1): self.model_name = model_name self.temperature = temperature self.max_tokens = max_tokens + self.max_retry = max_retry api_key = api_key or os.getenv("OPENROUTER_API_KEY") @@ -237,16 +254,11 @@ def __init__( class OpenAIChatModel(ChatModel): - def __init__( - self, - model_name, - api_key=None, - temperature=0.5, - max_tokens=100, - ): + def __init__(self, model_name, api_key=None, temperature=0.5, max_tokens=100, max_retry=1): self.model_name = model_name self.temperature = temperature self.max_tokens = max_tokens + self.max_retry = max_retry api_key = api_key or os.getenv("OPENAI_API_KEY") @@ -268,10 +280,12 @@ def __init__( deployment_name=None, temperature=0.5, max_tokens=100, + max_retry=1, ): self.model_name = model_name self.temperature = temperature self.max_tokens = max_tokens + self.max_retry = max_retry api_key = api_key or os.getenv("AZURE_OPENAI_API_KEY") From ea3eacf9698797a476d0b0c22b278e4120e2c7af Mon Sep 17 00:00:00 2001 From: ThibaultLSDC Date: Tue, 24 Sep 2024 16:11:26 -0400 Subject: [PATCH 19/37] fixing tests --- src/agentlab/llm/llm_configs.py | 8 ------ tests/agents/test_agent.py | 6 ++--- tests/llm/test_chat_api.py | 25 +------------------ ...ain_utils.py => test_huggingface_utils.py} | 23 +---------------- 4 files changed, 5 insertions(+), 57 deletions(-) rename tests/llm/{test_langchain_utils.py => test_huggingface_utils.py} (57%) diff --git a/src/agentlab/llm/llm_configs.py b/src/agentlab/llm/llm_configs.py index 268735fb..4c447b0a 100644 --- a/src/agentlab/llm/llm_configs.py +++ b/src/agentlab/llm/llm_configs.py @@ -2,7 +2,6 @@ from agentlab.llm.chat_api import ( AzureModelArgs, - HuggingFaceModelArgs, OpenAIModelArgs, OpenRouterModelArgs, SelfHostedModelArgs, @@ -66,13 +65,6 @@ max_input_tokens=7500, max_new_tokens=500, ), - "HuggingFaceH4/starchat-beta": HuggingFaceModelArgs( - model_name="HuggingFaceH4/starchat-beta", - max_total_tokens=8192, - max_input_tokens=8192 - 512, - max_new_tokens=512, - temperature=1e-1, - ), # ---------------- OSS LLMs ----------------# "meta-llama/Meta-Llama-3-70B-Instruct": SelfHostedModelArgs( model_name="meta-llama/Meta-Llama-3-70B-Instruct", diff --git a/tests/agents/test_agent.py b/tests/agents/test_agent.py index 3b1be95a..d3880219 100644 --- a/tests/agents/test_agent.py +++ b/tests/agents/test_agent.py @@ -9,7 +9,7 @@ from agentlab.agents.generic_agent.generic_agent import GenericAgentArgs from agentlab.analyze import inspect_results from agentlab.experiments import launch_exp -from agentlab.llm.chat_api import AIMessage, BaseModelArgs, CheatMiniWoBLLMArgs +from agentlab.llm.chat_api import BaseModelArgs, CheatMiniWoBLLMArgs def test_generic_agent(): @@ -52,7 +52,7 @@ class CheatMiniWoBLLM_Retry: def invoke(self, messages) -> str: if self.retry_count < self.n_retry: self.retry_count += 1 - return AIMessage(content="I'm retrying") + return dict(role="assistant", content="I'm retrying") prompt = messages[1].content match = re.search(r"^\s*\[(\d+)\].*button", prompt, re.MULTILINE | 
re.IGNORECASE) @@ -68,7 +68,7 @@ def invoke(self, messages) -> str: {action} """ - return AIMessage(content=answer) + return dict(role="assistant", content=answer) def __call__(self, messages) -> str: return self.invoke(messages) diff --git a/tests/llm/test_chat_api.py b/tests/llm/test_chat_api.py index aac05e2c..2768c3aa 100644 --- a/tests/llm/test_chat_api.py +++ b/tests/llm/test_chat_api.py @@ -2,7 +2,7 @@ import pytest -from agentlab.llm.chat_api import AzureModelArgs, HuggingFaceModelArgs, OpenAIModelArgs +from agentlab.llm.chat_api import AzureModelArgs, OpenAIModelArgs # TODO(optimass): figure out a good model for all tests @@ -13,29 +13,6 @@ skip_tests = False -# @pytest.mark.pricy -# @pytest.mark.skipif(skip_tests, reason="Skipping on remote as HF token have limited usage") -# def test_api_model_args_hf(): -# model_name = "HuggingFaceH4/starchat-beta" - -# model_args = HuggingFaceModelArgs( -# model_name=model_name, -# max_total_tokens=8192, -# max_input_tokens=8192 - 512, -# max_new_tokens=512, -# temperature=1e-1, -# ) -# model = model_args.make_model() - -# messages = [ -# SystemMessage(content="You are an helpful virtual assistant"), -# HumanMessage(content="Give the third prime number"), -# ] -# answer = model.invoke(messages) - -# assert "5" in answer.content - - @pytest.mark.pricy @pytest.mark.skipif(skip_tests, reason="Skipping on remote as Azure is pricy") def test_api_model_args_azure(): diff --git a/tests/llm/test_langchain_utils.py b/tests/llm/test_huggingface_utils.py similarity index 57% rename from tests/llm/test_langchain_utils.py rename to tests/llm/test_huggingface_utils.py index 5e2b7c6f..eb83aaeb 100644 --- a/tests/llm/test_langchain_utils.py +++ b/tests/llm/test_huggingface_utils.py @@ -1,33 +1,12 @@ import pytest -from agentlab.llm.chat_api import HuggingFaceAPIChatModel, HuggingFaceURLChatModel +from agentlab.llm.chat_api import HuggingFaceURLChatModel from agentlab.llm.llm_utils import download_and_save_model from agentlab.llm.prompt_templates import STARCHAT_PROMPT_TEMPLATE # TODO(optimass): figure out a good model for all tests -@pytest.mark.skip(reason="We can quickly hit the free tier limit on HuggingFace Hub") -def test_CustomLLMChatbot_remotely(): - # model_path = "google/flan-t5-base" # remote model on HuggingFace Hub - model_path = "HuggingFaceH4/starchat-beta" # remote model on HuggingFace Hub - - chatbot = HuggingFaceAPIChatModel( - model_path=model_path, - prompt_template=STARCHAT_PROMPT_TEMPLATE, - temperature=1e-3, - ) - - messages = [ - dict(role="system", content="You are an helpful virtual assistant"), - dict(role="user", content="Is python a programming language?"), - ] - - answer = chatbot(messages) - - print(answer.content) - - @pytest.mark.skip(reason="Requires a local model checkpoint") def test_CustomLLMChatbot_locally(): # model_path = "google/flan-t5-base" # remote model on HuggingFace Hub From 88797d3dfecb8834324e63a49cedaa52965b8a0d Mon Sep 17 00:00:00 2001 From: ThibaultLSDC Date: Wed, 25 Sep 2024 14:40:18 -0400 Subject: [PATCH 20/37] typos --- src/agentlab/llm/chat_api.py | 1 - tests/agents/test_agent.py | 2 +- 2 files changed, 1 insertion(+), 2 deletions(-) diff --git a/src/agentlab/llm/chat_api.py b/src/agentlab/llm/chat_api.py index 02482ca4..5478a131 100644 --- a/src/agentlab/llm/chat_api.py +++ b/src/agentlab/llm/chat_api.py @@ -4,7 +4,6 @@ import time from abc import ABC, abstractmethod from dataclasses import dataclass -from typing import TYPE_CHECKING import openai from openai import AzureOpenAI, OpenAI diff --git 
a/tests/agents/test_agent.py b/tests/agents/test_agent.py index d3880219..f51eddb6 100644 --- a/tests/agents/test_agent.py +++ b/tests/agents/test_agent.py @@ -54,7 +54,7 @@ def invoke(self, messages) -> str: self.retry_count += 1 return dict(role="assistant", content="I'm retrying") - prompt = messages[1].content + prompt = messages[1].get("content", "") match = re.search(r"^\s*\[(\d+)\].*button", prompt, re.MULTILINE | re.IGNORECASE) if match: From fe0db06b04223c0e0f2a5256d2f33216f74579b6 Mon Sep 17 00:00:00 2001 From: ThibaultLSDC Date: Thu, 26 Sep 2024 10:32:05 -0400 Subject: [PATCH 21/37] formatting --- src/agentlab/llm/llm_utils.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/agentlab/llm/llm_utils.py b/src/agentlab/llm/llm_utils.py index 0a6b1525..ec48ee50 100644 --- a/src/agentlab/llm/llm_utils.py +++ b/src/agentlab/llm/llm_utils.py @@ -99,7 +99,7 @@ def retry( tries += 1 if log: - msg = f"Query failed. Retrying {tries}/{n_retry}.\n[LLM]:\n{answer.get("content")}\n[User]:\n{retry_message}" + msg = f"Query failed. Retrying {tries}/{n_retry}.\n[LLM]:\n{answer.get('content')}\n[User]:\n{retry_message}" logging.info(msg) messages.append(dict(role="user", content=retry_message)) @@ -169,7 +169,7 @@ def retry_raise( except ParseError as parsing_error: tries += 1 if log: - msg = f"Query failed. Retrying {tries}/{n_retry}.\n[LLM]:\n{answer.get("content")}\n[User]:\n{str(parsing_error)}" + msg = f"Query failed. Retrying {tries}/{n_retry}.\n[LLM]:\n{answer.get('content')}\n[User]:\n{str(parsing_error)}" logging.info(msg) messages.append(dict(role="user", content=str(parsing_error))) From 340734115fa2688162e5a44488e22171e6708383 Mon Sep 17 00:00:00 2001 From: ThibaultLSDC Date: Thu, 26 Sep 2024 10:36:28 -0400 Subject: [PATCH 22/37] fixing imports --- tests/llm/test_tracking.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/tests/llm/test_tracking.py b/tests/llm/test_tracking.py index e22bac30..1060f073 100644 --- a/tests/llm/test_tracking.py +++ b/tests/llm/test_tracking.py @@ -121,7 +121,7 @@ def thread_2(results=None): @pytest.mark.pricy @pytest.mark.skipif(not OPENAI_API_KEY_AVAILABLE, reason="OpenAI API key is not available") def test_openai_chat_model(): - chat_model = tracking.OpenAIChatModel("gpt-4o-mini") + chat_model = OpenAIChatModel("gpt-4o-mini") assert chat_model.input_cost > 0 assert chat_model.output_cost > 0 @@ -146,7 +146,7 @@ def test_openai_chat_model(): not AZURE_OPENAI_API_KEY_AVAILABLE, reason="Azure OpenAI API key is not available" ) def test_azure_chat_model(): - chat_model = tracking.AzureChatModel(model_name="gpt-35-turbo", deployment_name="gpt-35-turbo") + chat_model = AzureChatModel(model_name="gpt-35-turbo", deployment_name="gpt-35-turbo") assert chat_model.input_cost > 0 assert chat_model.output_cost > 0 @@ -163,7 +163,7 @@ def test_azure_chat_model(): @pytest.mark.pricy @pytest.mark.skipif(not OPENROUTER_API_KEY_AVAILABLE, reason="OpenRouter API key is not available") def test_openrouter_chat_model(): - chat_model = tracking.OpenRouterChatModel("openai/gpt-4o-mini") + chat_model = OpenRouterChatModel("openai/gpt-4o-mini") assert chat_model.input_cost > 0 assert chat_model.output_cost > 0 From 6a774077681ed98d4c255cb8818212859b4cee6f Mon Sep 17 00:00:00 2001 From: ThibaultLSDC Date: Thu, 26 Sep 2024 11:38:18 -0400 Subject: [PATCH 23/37] retrocompat xray --- src/agentlab/analyze/agent_xray.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/src/agentlab/analyze/agent_xray.py 
b/src/agentlab/analyze/agent_xray.py index 51833972..c0845907 100644 --- a/src/agentlab/analyze/agent_xray.py +++ b/src/agentlab/analyze/agent_xray.py @@ -13,6 +13,7 @@ from browsergym.experiments.loop import ExpResult, StepInfo from openai import OpenAI from PIL import Image +from langchain.schema import BaseMessage from agentlab.analyze import inspect_results from agentlab.experiments.exp_utils import RESULTS_DIR @@ -559,7 +560,9 @@ def update_chat_messages(): chat_messages = agent_info.get("chat_messages", ["No Chat Messages"]) messages = [] for i, m in enumerate(chat_messages): - if isinstance(m, dict): + if isinstance(m, BaseMessage): + m = m.content + elif isinstance(m, dict): m = m.get("content", "No Content") messages.append(f"""# Message {i}\n```\n{m}\n```\n\n""") return "\n".join(messages) From df6842ca40e5ab1ef26ea9a3829a1b458811f8dd Mon Sep 17 00:00:00 2001 From: ThibaultLSDC Date: Thu, 26 Sep 2024 11:38:35 -0400 Subject: [PATCH 24/37] no rounding in stats --- src/agentlab/analyze/inspect_results.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/agentlab/analyze/inspect_results.py b/src/agentlab/analyze/inspect_results.py index c81752a7..d10334d8 100644 --- a/src/agentlab/analyze/inspect_results.py +++ b/src/agentlab/analyze/inspect_results.py @@ -291,9 +291,9 @@ def summarize_stats(sub_df): key_ = key.split(".")[1] op = key_.split("_")[0] if op == "cum": - record[key_] = sub_df[key].sum(skipna=True).round(6) + record[key_] = sub_df[key].sum(skipna=True) elif op == "max": - record[key_] = sub_df[key].max(skipna=True).round(6) + record[key_] = sub_df[key].max(skipna=True) else: raise ValueError(f"Unknown stats operation: {op}") return pd.Series(record) From 04bb9150a3363806ff6f1cece84f60aa20bdb294 Mon Sep 17 00:00:00 2001 From: ThibaultLSDC Date: Thu, 26 Sep 2024 11:58:26 -0400 Subject: [PATCH 25/37] retrocompat xray --- src/agentlab/analyze/agent_xray.py | 16 ++++++++++++---- 1 file changed, 12 insertions(+), 4 deletions(-) diff --git a/src/agentlab/analyze/agent_xray.py b/src/agentlab/analyze/agent_xray.py index c0845907..571f9886 100644 --- a/src/agentlab/analyze/agent_xray.py +++ b/src/agentlab/analyze/agent_xray.py @@ -11,9 +11,9 @@ import pandas as pd from attr import dataclass from browsergym.experiments.loop import ExpResult, StepInfo +from langchain.schema import BaseMessage, HumanMessage from openai import OpenAI from PIL import Image -from langchain.schema import BaseMessage from agentlab.analyze import inspect_results from agentlab.experiments.exp_utils import RESULTS_DIR @@ -627,11 +627,19 @@ def submit_action(input_text): global info agent_info = info.exp_result.steps_info[info.step].agent_info chat_messages = deepcopy(agent_info.get("chat_messages", ["No Chat Messages"])[:2]) - assert isinstance(chat_messages[1], dict), "Messages should be a dict" - assert chat_messages[1].get("role", None) == "user", "Second message should be user" + if isinstance(chat_messages[1], BaseMessage): + assert isinstance(chat_messages[1], HumanMessage), "Second message should be user" + chat_messages = [ + {"role": "system", "content": chat_messages[0].content}, + {"role": "user", "content": chat_messages[1].content}, + ] + elif isinstance(chat_messages[1], dict): + assert chat_messages[1].get("role", None) == "user", "Second message should be user" + else: + raise ValueError("Chat messages should be a list of BaseMessage or dict") client = OpenAI() - chat_messages[1].get("content") = input_text + chat_messages[1]["content"] = input_text completion = 
client.chat.completions.create( model="gpt-4o-mini", messages=chat_messages, From 42ab6502e820d622ffaa36a41931f870fc760de6 Mon Sep 17 00:00:00 2001 From: ThibaultLSDC Date: Thu, 26 Sep 2024 15:33:10 -0400 Subject: [PATCH 26/37] made message helper functions --- .../agents/generic_agent/generic_agent.py | 6 +++--- .../most_basic_agent/most_basic_agent.py | 5 +++-- src/agentlab/analyze/agent_xray.py | 5 +++-- src/agentlab/llm/chat_api.py | 16 +++++++++++++-- tests/agents/test_generic_prompt.py | 11 ++++------ tests/llm/test_chat_api.py | 15 +++++++++----- tests/llm/test_huggingface_utils.py | 6 +++--- tests/llm/test_llm_utils.py | 17 ++++++++-------- tests/llm/test_tracking.py | 20 ++++++++++++------- 9 files changed, 62 insertions(+), 39 deletions(-) diff --git a/src/agentlab/agents/generic_agent/generic_agent.py b/src/agentlab/agents/generic_agent/generic_agent.py index b5e77d92..987d5c32 100644 --- a/src/agentlab/agents/generic_agent/generic_agent.py +++ b/src/agentlab/agents/generic_agent/generic_agent.py @@ -6,7 +6,7 @@ from agentlab.agents import dynamic_prompting as dp from agentlab.agents.agent_args import AgentArgs -from agentlab.llm.chat_api import BaseModelArgs +from agentlab.llm.chat_api import BaseModelArgs, make_system_message, make_user_message from agentlab.llm.llm_utils import RetryError, retry_raise from agentlab.llm.tracking import cost_tracker_decorator @@ -97,8 +97,8 @@ def get_action(self, obs): # cause it to be too long chat_messages = [ - {"role": "system", "content": system_prompt}, - {"role": "user", "content": prompt}, + make_system_message(system_prompt), + make_user_message(prompt), ] ans_dict = retry_raise( self.chat_llm, diff --git a/src/agentlab/agents/most_basic_agent/most_basic_agent.py b/src/agentlab/agents/most_basic_agent/most_basic_agent.py index b44fe72f..986e06b6 100644 --- a/src/agentlab/agents/most_basic_agent/most_basic_agent.py +++ b/src/agentlab/agents/most_basic_agent/most_basic_agent.py @@ -8,6 +8,7 @@ from browsergym.experiments.agent import Agent, AgentInfo from browsergym.experiments.loop import AbstractAgentArgs, EnvArgs, ExpArgs +from agentlab.llm.chat_api import make_system_message, make_user_message from agentlab.llm.llm_configs import CHAT_MODEL_ARGS_DICT from agentlab.llm.llm_utils import ParseError, extract_code_blocks, retry_raise from agentlab.llm.tracking import cost_tracker_decorator @@ -84,8 +85,8 @@ def get_action(self, obs: Any) -> tuple[str, dict]: """ messages = [ - {"role": "system", "content": system_prompt}, - {"role": "user", "content": prompt}, + make_system_message(system_prompt), + make_user_message(prompt), ] def parser(response: str) -> tuple[dict, bool, str]: diff --git a/src/agentlab/analyze/agent_xray.py b/src/agentlab/analyze/agent_xray.py index 571f9886..e4ab82f6 100644 --- a/src/agentlab/analyze/agent_xray.py +++ b/src/agentlab/analyze/agent_xray.py @@ -17,6 +17,7 @@ from agentlab.analyze import inspect_results from agentlab.experiments.exp_utils import RESULTS_DIR +from agentlab.llm.chat_api import make_system_message, make_user_message select_dir_instructions = "Select Experiment Directory" AGENT_NAME_KEY = "agent.agent_name" @@ -630,8 +631,8 @@ def submit_action(input_text): if isinstance(chat_messages[1], BaseMessage): assert isinstance(chat_messages[1], HumanMessage), "Second message should be user" chat_messages = [ - {"role": "system", "content": chat_messages[0].content}, - {"role": "user", "content": chat_messages[1].content}, + make_system_message(chat_messages[0].content), + 
make_user_message(chat_messages[1].content), ] elif isinstance(chat_messages[1], dict): assert chat_messages[1].get("role", None) == "user", "Second message should be user" diff --git a/src/agentlab/llm/chat_api.py b/src/agentlab/llm/chat_api.py index 5478a131..bf0cb92b 100644 --- a/src/agentlab/llm/chat_api.py +++ b/src/agentlab/llm/chat_api.py @@ -13,6 +13,18 @@ from agentlab.llm.llm_utils import _extract_wait_time +def make_system_message(content: str) -> dict: + return dict(role="system", content=content) + + +def make_user_message(content: str) -> dict: + return dict(role="user", content=content) + + +def make_assistant_message(content: str) -> dict: + return dict(role="assistant", content=content) + + class CheatMiniWoBLLM: """For unit-testing purposes only. It only work with miniwob.click-test task.""" @@ -31,7 +43,7 @@ def invoke(self, messages) -> str: {action} """ - return {"role": "assistant", "content": answer} + return make_assistant_message(answer) def __call__(self, messages) -> str: return self.invoke(messages) @@ -226,7 +238,7 @@ def __call__(self, messages: list[dict]) -> dict: ): tracking.TRACKER.instance(input_tokens, output_tokens, cost) - return dict(role="assistant", content=completion.choices[0].message.content) + return make_assistant_message(completion.choices[0].message.content) def invoke(self, messages: list[dict]) -> dict: return self(messages) diff --git a/tests/agents/test_generic_prompt.py b/tests/agents/test_generic_prompt.py index 171d441e..712bc4db 100644 --- a/tests/agents/test_generic_prompt.py +++ b/tests/agents/test_generic_prompt.py @@ -1,15 +1,12 @@ from copy import deepcopy -from agentlab.agents import dynamic_prompting as dp -from agentlab.agents.generic_agent.generic_agent_prompt import ( - MainPrompt, - GenericPromptFlags, -) -from agentlab.agents.generic_agent.agent_configs import FLAGS_GPT_3_5 + import pytest +from agentlab.agents import dynamic_prompting as dp +from agentlab.agents.generic_agent.agent_configs import FLAGS_GPT_3_5 +from agentlab.agents.generic_agent.generic_agent_prompt import GenericPromptFlags, MainPrompt from agentlab.llm.llm_utils import count_tokens - html_template = """ diff --git a/tests/llm/test_chat_api.py b/tests/llm/test_chat_api.py index 2768c3aa..b49f3588 100644 --- a/tests/llm/test_chat_api.py +++ b/tests/llm/test_chat_api.py @@ -2,7 +2,12 @@ import pytest -from agentlab.llm.chat_api import AzureModelArgs, OpenAIModelArgs +from agentlab.llm.chat_api import ( + AzureModelArgs, + OpenAIModelArgs, + make_system_message, + make_user_message, +) # TODO(optimass): figure out a good model for all tests @@ -27,8 +32,8 @@ def test_api_model_args_azure(): model = model_args.make_model() messages = [ - dict(role="system", content="You are an helpful virtual assistant"), - dict(role="user", content="Give the third prime number"), + make_system_message("You are an helpful virtual assistant"), + make_user_message("Give the third prime number"), ] answer = model.invoke(messages) @@ -48,8 +53,8 @@ def test_api_model_args_openai(): model = model_args.make_model() messages = [ - dict(role="system", content="You are an helpful virtual assistant"), - dict(role="user", content="Give the third prime number"), + make_system_message("You are an helpful virtual assistant"), + make_user_message("Give the third prime number"), ] answer = model.invoke(messages) diff --git a/tests/llm/test_huggingface_utils.py b/tests/llm/test_huggingface_utils.py index eb83aaeb..43cf88de 100644 --- a/tests/llm/test_huggingface_utils.py +++ 
b/tests/llm/test_huggingface_utils.py @@ -1,6 +1,6 @@ import pytest -from agentlab.llm.chat_api import HuggingFaceURLChatModel +from agentlab.llm.chat_api import HuggingFaceURLChatModel, make_system_message, make_user_message from agentlab.llm.llm_utils import download_and_save_model from agentlab.llm.prompt_templates import STARCHAT_PROMPT_TEMPLATE @@ -15,8 +15,8 @@ def test_CustomLLMChatbot_locally(): chatbot = HuggingFaceURLChatModel(model_path=model_path, temperature=1e-3) messages = [ - dict(role="system", content="Please tell me back the following word: "), - dict(role="user", content="bird"), + make_system_message("Please tell me back the following word: "), + make_user_message("bird"), ] answer = chatbot(messages) diff --git a/tests/llm/test_llm_utils.py b/tests/llm/test_llm_utils.py index c8c59400..63170476 100644 --- a/tests/llm/test_llm_utils.py +++ b/tests/llm/test_llm_utils.py @@ -7,6 +7,7 @@ from openai import RateLimitError from agentlab.llm import llm_utils +from agentlab.llm.chat_api import make_system_message yaml_str = """Analysis: This is the analysis @@ -148,7 +149,7 @@ def test_rate_limit_success(): mock_chat.invoke = Mock( side_effect=[ mock_rate_limit_error("Rate limit reached. Please try again in 2s."), - dict(role="system", content="correct content"), + make_system_message("correct content"), ] ) @@ -173,9 +174,9 @@ def test_successful_parse_before_max_retries(): # content on the 3rd time mock_chat.invoke = Mock( side_effect=[ - dict(role="system", content="wrong content"), - dict(role="system", content="wrong content"), - dict(role="system", content="correct content"), + make_system_message("wrong content"), + make_system_message("wrong content"), + make_system_message("correct content"), ] ) @@ -192,9 +193,9 @@ def test_unsuccessful_parse_before_max_retries(): # content on the 3rd time mock_chat.invoke = Mock( side_effect=[ - dict(role="system", content="wrong content"), - dict(role="system", content="wrong content"), - dict(role="system", content="correct content"), + make_system_message("wrong content"), + make_system_message("wrong content"), + make_system_message("correct content"), ] ) with pytest.raises(ValueError): @@ -205,7 +206,7 @@ def test_unsuccessful_parse_before_max_retries(): def test_retry_parse_raises(): mock_chat = MockChatOpenAI() - mock_chat.invoke = Mock(return_value=dict(role="system", content="mocked response")) + mock_chat.invoke = Mock(return_value=make_system_message("mocked response")) parser_raises = Mock(side_effect=ValueError("Parser error")) with pytest.raises(ValueError): diff --git a/tests/llm/test_tracking.py b/tests/llm/test_tracking.py index 1060f073..cc5abd36 100644 --- a/tests/llm/test_tracking.py +++ b/tests/llm/test_tracking.py @@ -5,7 +5,13 @@ import pytest import agentlab.llm.tracking as tracking -from agentlab.llm.chat_api import AzureChatModel, OpenAIChatModel, OpenRouterChatModel +from agentlab.llm.chat_api import ( + AzureChatModel, + OpenAIChatModel, + OpenRouterChatModel, + make_system_message, + make_user_message, +) def test_get_action_decorator(): @@ -126,8 +132,8 @@ def test_openai_chat_model(): assert chat_model.output_cost > 0 messages = [ - dict(role="system", content="You are an helpful virtual assistant"), - dict(role="user", content="Give the third prime number"), + make_system_message("You are an helpful virtual assistant"), + make_user_message("Give the third prime number"), ] with tracking.set_tracker() as tracker: answer = chat_model.invoke(messages) @@ -151,8 +157,8 @@ def test_azure_chat_model(): 
assert chat_model.output_cost > 0 messages = [ - dict(role="system", content="You are an helpful virtual assistant"), - dict(role="user", content="Give the third prime number"), + make_system_message("You are an helpful virtual assistant"), + make_user_message("Give the third prime number"), ] with tracking.set_tracker() as tracker: answer = chat_model.invoke(messages) @@ -168,8 +174,8 @@ def test_openrouter_chat_model(): assert chat_model.output_cost > 0 messages = [ - dict(role="system", content="You are an helpful virtual assistant"), - dict(role="user", content="Give the third prime number"), + make_system_message("You are an helpful virtual assistant"), + make_user_message("Give the third prime number"), ] with tracking.set_tracker() as tracker: answer = chat_model.invoke(messages) From 67f36ad0c697c11e50699db14b9471e7e114a3c6 Mon Sep 17 00:00:00 2001 From: Thibault LSDC <78021491+ThibaultLSDC@users.noreply.github.com> Date: Thu, 26 Sep 2024 15:41:18 -0400 Subject: [PATCH 27/37] Update src/agentlab/llm/chat_api.py Co-authored-by: Alexandre Lacoste --- src/agentlab/llm/chat_api.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/agentlab/llm/chat_api.py b/src/agentlab/llm/chat_api.py index 5478a131..7f253e1b 100644 --- a/src/agentlab/llm/chat_api.py +++ b/src/agentlab/llm/chat_api.py @@ -17,7 +17,7 @@ class CheatMiniWoBLLM: """For unit-testing purposes only. It only work with miniwob.click-test task.""" def invoke(self, messages) -> str: - prompt = messages[-1].get("content", "") + prompt = messages[-1]["content"] match = re.search(r"^\s*\[(\d+)\].*button", prompt, re.MULTILINE | re.IGNORECASE) if match: From 307f4f2515aff54d5c0d6db85d4842a2cbe3bd04 Mon Sep 17 00:00:00 2001 From: ThibaultLSDC Date: Thu, 26 Sep 2024 15:44:18 -0400 Subject: [PATCH 28/37] doc --- src/agentlab/analyze/agent_xray.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/agentlab/analyze/agent_xray.py b/src/agentlab/analyze/agent_xray.py index e4ab82f6..1bc1877b 100644 --- a/src/agentlab/analyze/agent_xray.py +++ b/src/agentlab/analyze/agent_xray.py @@ -561,7 +561,7 @@ def update_chat_messages(): chat_messages = agent_info.get("chat_messages", ["No Chat Messages"]) messages = [] for i, m in enumerate(chat_messages): - if isinstance(m, BaseMessage): + if isinstance(m, BaseMessage): # TODO remove once langchain is deprecated m = m.content elif isinstance(m, dict): m = m.get("content", "No Content") @@ -628,7 +628,7 @@ def submit_action(input_text): global info agent_info = info.exp_result.steps_info[info.step].agent_info chat_messages = deepcopy(agent_info.get("chat_messages", ["No Chat Messages"])[:2]) - if isinstance(chat_messages[1], BaseMessage): + if isinstance(chat_messages[1], BaseMessage): # TODO remove once langchain is deprecated assert isinstance(chat_messages[1], HumanMessage), "Second message should be user" chat_messages = [ make_system_message(chat_messages[0].content), From cf8720b6b16a348ffd31e9106ba886048a8290ba Mon Sep 17 00:00:00 2001 From: ThibaultLSDC Date: Thu, 26 Sep 2024 15:48:39 -0400 Subject: [PATCH 29/37] specific retry exception --- src/agentlab/llm/chat_api.py | 27 ++++++++++++++++++++++++--- 1 file changed, 24 insertions(+), 3 deletions(-) diff --git a/src/agentlab/llm/chat_api.py b/src/agentlab/llm/chat_api.py index bf0cb92b..b3874f26 100644 --- a/src/agentlab/llm/chat_api.py +++ b/src/agentlab/llm/chat_api.py @@ -191,14 +191,29 @@ def make_model(self): pass +class RetryError(Exception): + pass + + class ChatModel(ABC): @abstractmethod - 
def __init__(self, model_name, api_key=None, temperature=0.5, max_tokens=100, max_retry=1): + def __init__( + self, + model_name, + api_key=None, + temperature=0.5, + max_tokens=100, + max_retry=1, + min_retry_wait_time=60, + ): + assert max_retry > 0, "max_retry should be greater than 0" + self.model_name = model_name self.temperature = temperature self.max_tokens = max_tokens self.max_retry = max_retry + self.min_retry_wait_time = min_retry_wait_time self.client = OpenAI() @@ -221,13 +236,19 @@ def __call__(self, messages: list[dict]) -> dict: f"Failed to get a response from the API: \n{e}\n" f"Retrying... ({itr+1}/{self.max_retry})" ) - wait_time = _extract_wait_time(e) + wait_time = _extract_wait_time( + e.args[0], + min_retry_wait_time=self.min_retry_wait_time, + ) logging.info(f"Waiting for {wait_time} seconds") time.sleep(wait_time) # TODO: add total delay limit ? if not completion: - raise Exception("Failed to get a response from the API") + raise RetryError( + f"Failed to get a response from the API after {self.max_retry} retries\n\ +Last error: {e}" + ) input_tokens = completion.usage.prompt_tokens output_tokens = completion.usage.completion_tokens From 42515dc6b3a578953307017ddbc0bddcf939810f Mon Sep 17 00:00:00 2001 From: ThibaultLSDC Date: Thu, 26 Sep 2024 15:56:37 -0400 Subject: [PATCH 30/37] bye retry --- src/agentlab/llm/llm_utils.py | 77 +---------------------------------- 1 file changed, 2 insertions(+), 75 deletions(-) diff --git a/src/agentlab/llm/llm_utils.py b/src/agentlab/llm/llm_utils.py index ec48ee50..b90a2e0e 100644 --- a/src/agentlab/llm/llm_utils.py +++ b/src/agentlab/llm/llm_utils.py @@ -33,79 +33,6 @@ class RetryError(ValueError): pass -def retry( - chat: "ChatModel", - messages: list[dict], - n_retry, - parser, - log=True, - min_retry_wait_time=60, - rate_limit_max_wait_time=60 * 30, -): - """Retry querying the chat models with the response from the parser until it - returns a valid value. - - If the answer is not valid, it will retry and append to the chat the retry - message. It will stop after `n_retry`. - - Note, each retry has to resend the whole prompt to the API. This can be slow - and expensive. - - Args: - chat (ChatModel): a ChatModel object taking a list of messages and - returning a list of answers, all in OpenAI format. - messages (list): the list of messages so far, in OpenAI format. - n_retry (int): the maximum number of sequential retries. - parser (function): a function taking a message and returning a tuple - with the following fields: - value : the parsed value, - valid : a boolean indicating if the value is valid, - retry_message : a message to send to the chat if the value is not valid - log (bool): whether to log the retry messages. - min_retry_wait_time (float): the minimum wait time in seconds - after RateLimtError. will try to parse the wait time from the error - message. - rate_limit_max_wait_time (int): the maximum wait time in seconds - - Returns: - dict: the parsed value, with a string at key "action". - - Raises: - RetryError: if the parser could not parse a valid value after n_retry retries. - RateLimitError: if the requests exceed the rate limit. 
- """ - tries = 0 - rate_limit_total_delay = 0 - while tries < n_retry and rate_limit_total_delay < rate_limit_max_wait_time: - try: - answer = chat.invoke(messages) - except RateLimitError as e: - wait_time = _extract_wait_time(e.args[0], min_retry_wait_time) - logging.warning(f"RateLimitError, waiting {wait_time}s before retrying.") - time.sleep(wait_time) - rate_limit_total_delay += wait_time - if rate_limit_total_delay >= rate_limit_max_wait_time: - logging.warning( - f"Total wait time for rate limit exceeded. Waited {rate_limit_total_delay}s > {rate_limit_max_wait_time}s." - ) - raise - continue - - messages.append(answer) - - value, valid, retry_message = parser(answer.get("content")) - if valid: - return value - - tries += 1 - if log: - msg = f"Query failed. Retrying {tries}/{n_retry}.\n[LLM]:\n{answer.get('content')}\n[User]:\n{retry_message}" - logging.info(msg) - messages.append(dict(role="user", content=retry_message)) - - raise RetryError(f"Could not parse a valid value after {n_retry} retries.") - - def retry_raise( chat: "ChatModel", messages: list[dict], @@ -165,11 +92,11 @@ def retry_raise( messages.append(answer) # TODO: could we change this to not use inplace modifications ? try: - return parser(answer.get("content")) + return parser(answer["content"]) except ParseError as parsing_error: tries += 1 if log: - msg = f"Query failed. Retrying {tries}/{n_retry}.\n[LLM]:\n{answer.get('content')}\n[User]:\n{str(parsing_error)}" + msg = f"Query failed. Retrying {tries}/{n_retry}.\n[LLM]:\n{answer['content']}\n[User]:\n{str(parsing_error)}" logging.info(msg) messages.append(dict(role="user", content=str(parsing_error))) From 5dee506ae126e65894caee3eeb471e68aaeb0955 Mon Sep 17 00:00:00 2001 From: ThibaultLSDC Date: Thu, 26 Sep 2024 16:44:52 -0400 Subject: [PATCH 31/37] welcome back retry --- src/agentlab/agents/generic_agent/generic_agent.py | 4 ++-- src/agentlab/agents/most_basic_agent/most_basic_agent.py | 4 ++-- src/agentlab/llm/llm_utils.py | 2 +- tests/llm/test_llm_utils.py | 6 +++--- 4 files changed, 8 insertions(+), 8 deletions(-) diff --git a/src/agentlab/agents/generic_agent/generic_agent.py b/src/agentlab/agents/generic_agent/generic_agent.py index 987d5c32..1b781e1c 100644 --- a/src/agentlab/agents/generic_agent/generic_agent.py +++ b/src/agentlab/agents/generic_agent/generic_agent.py @@ -7,7 +7,7 @@ from agentlab.agents import dynamic_prompting as dp from agentlab.agents.agent_args import AgentArgs from agentlab.llm.chat_api import BaseModelArgs, make_system_message, make_user_message -from agentlab.llm.llm_utils import RetryError, retry_raise +from agentlab.llm.llm_utils import RetryError, retry from agentlab.llm.tracking import cost_tracker_decorator from .generic_agent_prompt import GenericPromptFlags, MainPrompt @@ -100,7 +100,7 @@ def get_action(self, obs): make_system_message(system_prompt), make_user_message(prompt), ] - ans_dict = retry_raise( + ans_dict = retry( self.chat_llm, chat_messages, n_retry=self.max_retry, diff --git a/src/agentlab/agents/most_basic_agent/most_basic_agent.py b/src/agentlab/agents/most_basic_agent/most_basic_agent.py index 986e06b6..210e726f 100644 --- a/src/agentlab/agents/most_basic_agent/most_basic_agent.py +++ b/src/agentlab/agents/most_basic_agent/most_basic_agent.py @@ -10,7 +10,7 @@ from agentlab.llm.chat_api import make_system_message, make_user_message from agentlab.llm.llm_configs import CHAT_MODEL_ARGS_DICT -from agentlab.llm.llm_utils import ParseError, extract_code_blocks, retry_raise +from agentlab.llm.llm_utils import 
ParseError, extract_code_blocks, retry from agentlab.llm.tracking import cost_tracker_decorator if TYPE_CHECKING: @@ -97,7 +97,7 @@ def parser(response: str) -> tuple[dict, bool, str]: thought = response return {"action": action, "think": thought} - ans_dict = retry_raise(self.chat, messages, n_retry=3, parser=parser) + ans_dict = retry(self.chat, messages, n_retry=3, parser=parser) action = ans_dict.get("action", None) thought = ans_dict.get("think", None) diff --git a/src/agentlab/llm/llm_utils.py b/src/agentlab/llm/llm_utils.py index b90a2e0e..5da50f49 100644 --- a/src/agentlab/llm/llm_utils.py +++ b/src/agentlab/llm/llm_utils.py @@ -33,7 +33,7 @@ class RetryError(ValueError): pass -def retry_raise( +def retry( chat: "ChatModel", messages: list[dict], n_retry: int, diff --git a/tests/llm/test_llm_utils.py b/tests/llm/test_llm_utils.py index 63170476..ae74bd02 100644 --- a/tests/llm/test_llm_utils.py +++ b/tests/llm/test_llm_utils.py @@ -99,9 +99,9 @@ def invoke(self, messages): def mock_parser(answer): if answer == "correct content": - return "Parsed value", True, "" - - return None, False, "Retry message" + return "Parsed value" + else: + raise llm_utils.ParseError("Retry message") def mock_rate_limit_error(message: str, status_code: Literal[429] = 429) -> RateLimitError: From 7e0adbcf2e485c603bbef8b5359e46ba36fbc19f Mon Sep 17 00:00:00 2001 From: ThibaultLSDC Date: Fri, 27 Sep 2024 11:18:42 -0400 Subject: [PATCH 32/37] moving API errors to ChatModel, restructuring ChatModels --- .../agents/generic_agent/generic_agent.py | 9 +- src/agentlab/llm/chat_api.py | 192 ++++++++++++------ src/agentlab/llm/llm_utils.py | 30 +-- 3 files changed, 130 insertions(+), 101 deletions(-) diff --git a/src/agentlab/agents/generic_agent/generic_agent.py b/src/agentlab/agents/generic_agent/generic_agent.py index 1b781e1c..e0ff5b6b 100644 --- a/src/agentlab/agents/generic_agent/generic_agent.py +++ b/src/agentlab/agents/generic_agent/generic_agent.py @@ -91,7 +91,6 @@ def get_action(self, obs): additional_prompts=system_prompt, ) - stats = {} try: # TODO, we would need to further shrink the prompt if the retry # cause it to be too long @@ -106,14 +105,8 @@ def get_action(self, obs): n_retry=self.max_retry, parser=main_prompt._parse_answer, ) - # inferring the number of retries, TODO: make this less hacky - stats["n_retry"] = (len(chat_messages) - 3) / 2 - stats["busted_retry"] = 0 except RetryError as e: ans_dict = {"action": None} - stats["busted_retry"] = 1 - - stats["n_retry"] = self.max_retry + 1 self.plan = ans_dict.get("plan", self.plan) self.plan_step = ans_dict.get("step", self.plan_step) @@ -121,6 +114,8 @@ def get_action(self, obs): self.memories.append(ans_dict.get("memory", None)) self.thoughts.append(ans_dict.get("think", None)) + stats = self.chat_llm.get_stats() + agent_info = dict( think=ans_dict.get("think", None), chat_messages=chat_messages, diff --git a/src/agentlab/llm/chat_api.py b/src/agentlab/llm/chat_api.py index b3874f26..f651588d 100644 --- a/src/agentlab/llm/chat_api.py +++ b/src/agentlab/llm/chat_api.py @@ -10,7 +10,6 @@ import agentlab.llm.tracking as tracking from agentlab.llm.huggingface_utils import HuggingFaceURLChatModel -from agentlab.llm.llm_utils import _extract_wait_time def make_system_message(content: str) -> dict: @@ -48,6 +47,9 @@ def invoke(self, messages) -> str: def __call__(self, messages) -> str: return self.invoke(messages) + def get_stats(self): + return {} + @dataclass class CheatMiniWoBLLMArgs: @@ -191,21 +193,47 @@ def make_model(self): pass +def 
_extract_wait_time(error_message, min_retry_wait_time=60): + """Extract the wait time from an OpenAI RateLimitError message.""" + match = re.search(r"try again in (\d+(\.\d+)?)s", error_message) + if match: + return max(min_retry_wait_time, float(match.group(1))) + return min_retry_wait_time + + class RetryError(Exception): pass -class ChatModel(ABC): +def handle_error(error, itr, min_retry_wait_time, max_retry): + if not isinstance(error, openai.OpenAIError): + raise error + logging.warning( + f"Failed to get a response from the API: \n{error}\n" f"Retrying... ({itr+1}/{max_retry})" + ) + wait_time = _extract_wait_time( + error.args[0], + min_retry_wait_time=min_retry_wait_time, + ) + logging.info(f"Waiting for {wait_time} seconds") + time.sleep(wait_time) + error_type = error.args[0] + return error_type - @abstractmethod + +class ChatModel: def __init__( self, model_name, api_key=None, temperature=0.5, max_tokens=100, - max_retry=1, + max_retry=4, min_retry_wait_time=60, + api_key_env_var=None, + client_class=OpenAI, + client_args=None, + pricing_func=None, ): assert max_retry > 0, "max_retry should be greater than 0" @@ -215,14 +243,36 @@ def __init__( self.max_retry = max_retry self.min_retry_wait_time = min_retry_wait_time - self.client = OpenAI() + # Get the API key from the environment variable if not provided + if api_key_env_var: + api_key = api_key or os.getenv(api_key_env_var) + self.api_key = api_key + + # Get pricing information + if pricing_func: + pricings = pricing_func() + self.input_cost = float(pricings[model_name]["prompt"]) + self.output_cost = float(pricings[model_name]["completion"]) + else: + self.input_cost = 0.0 + self.output_cost = 0.0 - self.input_cost = 0.0 - self.output_cost = 0.0 + client_args = client_args or {} + self.client = client_class( + api_key=api_key, + **client_args, + ) def __call__(self, messages: list[dict]) -> dict: + # Initialize retry tracking attributes + self.retries = 0 + self.success = False + self.error_types = [] + completion = None + e = None for itr in range(self.max_retry): + self.retries += 1 try: completion = self.client.chat.completions.create( model=self.model_name, @@ -230,24 +280,16 @@ def __call__(self, messages: list[dict]) -> dict: temperature=self.temperature, max_tokens=self.max_tokens, ) + self.success = True break except openai.OpenAIError as e: - logging.warning( - f"Failed to get a response from the API: \n{e}\n" - f"Retrying... ({itr+1}/{self.max_retry})" - ) - wait_time = _extract_wait_time( - e.args[0], - min_retry_wait_time=self.min_retry_wait_time, - ) - logging.info(f"Waiting for {wait_time} seconds") - time.sleep(wait_time) - # TODO: add total delay limit ? 
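# handle_error (defined above) re-raises anything that is not an
# openai.OpenAIError; for OpenAI errors it logs a warning, sleeps for the
# backoff parsed by _extract_wait_time (e.g. "Please try again in 90.5s."
# -> max(min_retry_wait_time, 90.5)), and returns the error message so the
# caller can record it in self.error_types.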
+ error_type = handle_error(e, itr, self.min_retry_wait_time, self.max_retry) + self.error_types.append(error_type) if not completion: raise RetryError( - f"Failed to get a response from the API after {self.max_retry} retries\n\ -Last error: {e}" + f"Failed to get a response from the API after {self.max_retry} retries\n" + f"Last error: {e}" ) input_tokens = completion.usage.prompt_tokens @@ -264,43 +306,60 @@ def __call__(self, messages: list[dict]) -> dict: def invoke(self, messages: list[dict]) -> dict: return self(messages) + def get_stats(self): + return { + "n_retry": self.retries, + "busted_retry": int(not self.success), + } -class OpenRouterChatModel(ChatModel): - def __init__(self, model_name, api_key=None, temperature=0.5, max_tokens=100, max_retry=1): - self.model_name = model_name - self.temperature = temperature - self.max_tokens = max_tokens - self.max_retry = max_retry - - api_key = api_key or os.getenv("OPENROUTER_API_KEY") - - pricings = tracking.get_pricing_openrouter() - - self.input_cost = pricings[model_name]["prompt"] - self.output_cost = pricings[model_name]["completion"] - self.client = OpenAI( - base_url="https://openrouter.ai/api/v1", +class OpenAIChatModel(ChatModel): + def __init__( + self, + model_name, + api_key=None, + temperature=0.5, + max_tokens=100, + max_retry=4, + min_retry_wait_time=60, + ): + super().__init__( + model_name=model_name, api_key=api_key, + temperature=temperature, + max_tokens=max_tokens, + max_retry=max_retry, + min_retry_wait_time=min_retry_wait_time, + api_key_env_var="OPENAI_API_KEY", + client_class=OpenAI, + pricing_func=tracking.get_pricing_openai, ) -class OpenAIChatModel(ChatModel): - def __init__(self, model_name, api_key=None, temperature=0.5, max_tokens=100, max_retry=1): - self.model_name = model_name - self.temperature = temperature - self.max_tokens = max_tokens - self.max_retry = max_retry - - api_key = api_key or os.getenv("OPENAI_API_KEY") - - pricings = tracking.get_pricing_openai() - - self.input_cost = float(pricings[model_name]["prompt"]) - self.output_cost = float(pricings[model_name]["completion"]) - - self.client = OpenAI( +class OpenRouterChatModel(ChatModel): + def __init__( + self, + model_name, + api_key=None, + temperature=0.5, + max_tokens=100, + max_retry=4, + min_retry_wait_time=60, + ): + client_args = { + "base_url": "https://openrouter.ai/api/v1", + } + super().__init__( + model_name=model_name, api_key=api_key, + temperature=temperature, + max_tokens=max_tokens, + max_retry=max_retry, + min_retry_wait_time=min_retry_wait_time, + api_key_env_var="OPENROUTER_API_KEY", + client_class=OpenAI, + client_args=client_args, + pricing_func=tracking.get_pricing_openrouter, ) @@ -312,27 +371,26 @@ def __init__( deployment_name=None, temperature=0.5, max_tokens=100, - max_retry=1, + max_retry=4, + min_retry_wait_time=60, ): - self.model_name = model_name - self.temperature = temperature - self.max_tokens = max_tokens - self.max_retry = max_retry - api_key = api_key or os.getenv("AZURE_OPENAI_API_KEY") - - # AZURE_OPENAI_ENDPOINT has to be defined in the environment endpoint = os.getenv("AZURE_OPENAI_ENDPOINT") assert endpoint, "AZURE_OPENAI_ENDPOINT has to be defined in the environment" - pricings = tracking.get_pricing_openai() - - self.input_cost = float(pricings[model_name]["prompt"]) - self.output_cost = float(pricings[model_name]["completion"]) - - self.client = AzureOpenAI( + client_args = { + "azure_deployment": deployment_name, + "azure_endpoint": endpoint, + "api_version": "2024-02-01", + } + 
super().__init__( + model_name=model_name, api_key=api_key, - azure_deployment=deployment_name, - azure_endpoint=endpoint, - api_version="2024-02-01", + temperature=temperature, + max_tokens=max_tokens, + max_retry=max_retry, + min_retry_wait_time=min_retry_wait_time, + client_class=AzureOpenAI, + client_args=client_args, + pricing_func=tracking.get_pricing_openai, ) diff --git a/src/agentlab/llm/llm_utils.py b/src/agentlab/llm/llm_utils.py index 5da50f49..afbcb27d 100644 --- a/src/agentlab/llm/llm_utils.py +++ b/src/agentlab/llm/llm_utils.py @@ -21,14 +21,6 @@ from agentlab.llm.chat_api import ChatModel -def _extract_wait_time(error_message, min_retry_wait_time=60): - """Extract the wait time from an OpenAI RateLimitError message.""" - match = re.search(r"try again in (\d+(\.\d+)?)s", error_message) - if match: - return max(min_retry_wait_time, float(match.group(1))) - return min_retry_wait_time - - class RetryError(ValueError): pass @@ -39,8 +31,6 @@ def retry( n_retry: int, parser: callable, log: bool = True, - min_retry_wait_time: int = 60, - rate_limit_max_wait_time: int = 60 * 30, ): """Retry querying the chat models with the response from the parser until it returns a valid value. @@ -73,22 +63,8 @@ def retry( RateLimitError: if the requests exceed the rate limit. """ tries = 0 - rate_limit_total_delay = 0 - while tries < n_retry and rate_limit_total_delay < rate_limit_max_wait_time: - try: - answer = chat.invoke(messages) - except RateLimitError as e: - wait_time = _extract_wait_time(e.args[0], min_retry_wait_time) - logging.warning(f"RateLimitError, waiting {wait_time}s before retrying.") - time.sleep(wait_time) - rate_limit_total_delay += wait_time - if rate_limit_total_delay >= rate_limit_max_wait_time: - logging.warning( - f"Total wait time for rate limit exceeded. Waited {rate_limit_total_delay}s > {rate_limit_max_wait_time}s." - ) - raise - continue - + while tries < n_retry: + answer = chat.invoke(messages) messages.append(answer) # TODO: could we change this to not use inplace modifications ? 
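# The parser contract here is exception-based: a parser returns the final
# value on success and raises ParseError with a corrective message otherwise.
# A minimal conforming parser (hypothetical example):
#
#     def parse_action(content: str) -> dict:
#         if "click(" not in content:
#             raise ParseError("Reply with a single click(...) action.")
#         return {"action": content.strip()}
#
# The ParseError text is appended below as a user turn so the model can
# correct itself on the next attempt.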
try: @@ -100,7 +76,7 @@ def retry( logging.info(msg) messages.append(dict(role="user", content=str(parsing_error))) - raise RetryError(f"Could not parse a valid value after {n_retry} retries.") + raise ParseError(f"Could not parse a valid value after {n_retry} retries.") def truncate_tokens(text, max_tokens=8000, start=0, model_name="gpt-4"): From 5bd7be6c8b617ed5d392814f48befe58cb24bfe2 Mon Sep 17 00:00:00 2001 From: ThibaultLSDC Date: Fri, 27 Sep 2024 16:26:07 -0400 Subject: [PATCH 33/37] fix test --- .../agents/generic_agent/generic_agent.py | 20 +++-- src/agentlab/llm/chat_api.py | 4 +- tests/agents/test_agent.py | 3 + tests/llm/test_llm_utils.py | 84 +++++++++---------- 4 files changed, 61 insertions(+), 50 deletions(-) diff --git a/src/agentlab/agents/generic_agent/generic_agent.py b/src/agentlab/agents/generic_agent/generic_agent.py index 218dded1..20d1cc10 100644 --- a/src/agentlab/agents/generic_agent/generic_agent.py +++ b/src/agentlab/agents/generic_agent/generic_agent.py @@ -7,7 +7,7 @@ from agentlab.agents import dynamic_prompting as dp from agentlab.agents.agent_args import AgentArgs from agentlab.llm.chat_api import BaseModelArgs, make_system_message, make_user_message -from agentlab.llm.llm_utils import RetryError, retry +from agentlab.llm.llm_utils import ParseError, retry from agentlab.llm.tracking import cost_tracker_decorator from .generic_agent_prompt import GenericPromptFlags, MainPrompt @@ -90,7 +90,6 @@ def get_action(self, obs): max_iterations=max_trunc_itr, additional_prompts=system_prompt, ) - try: # TODO, we would need to further shrink the prompt if the retry # cause it to be too long @@ -105,8 +104,19 @@ def get_action(self, obs): n_retry=self.max_retry, parser=main_prompt._parse_answer, ) - except RetryError as e: - ans_dict = {"action": None} + ans_dict["busted_retry"] = 0 + # inferring the number of retries, TODO: make this less hacky + ans_dict["n_retry"] = (len(chat_messages) - 3) / 2 + except ParseError as e: + ans_dict = dict( + action=None, + n_retry=self.max_retry + 1, + busted_retry=1, + ) + + stats = self.chat_llm.get_stats() + stats["n_retry"] = ans_dict["n_retry"] + stats["busted_retry"] = ans_dict["busted_retry"] self.plan = ans_dict.get("plan", self.plan) self.plan_step = ans_dict.get("step", self.plan_step) @@ -114,8 +124,6 @@ def get_action(self, obs): self.memories.append(ans_dict.get("memory", None)) self.thoughts.append(ans_dict.get("think", None)) - stats = self.chat_llm.get_stats() - agent_info = AgentInfo( think=ans_dict.get("think", None), chat_messages=chat_messages, diff --git a/src/agentlab/llm/chat_api.py b/src/agentlab/llm/chat_api.py index 3f8b7e54..e910cd08 100644 --- a/src/agentlab/llm/chat_api.py +++ b/src/agentlab/llm/chat_api.py @@ -308,8 +308,8 @@ def invoke(self, messages: list[dict]) -> dict: def get_stats(self): return { - "n_retry": self.retries, - "busted_retry": int(not self.success), + "n_retry_llm": self.retries, + "busted_retry_llm": int(not self.success), } diff --git a/tests/agents/test_agent.py b/tests/agents/test_agent.py index e1d98bee..dfb7923f 100644 --- a/tests/agents/test_agent.py +++ b/tests/agents/test_agent.py @@ -73,6 +73,9 @@ def invoke(self, messages) -> str: def __call__(self, messages) -> str: return self.invoke(messages) + def get_stats(self): + return {"n_retry": self.n_retry, "busted_retry": self.retry_count} + @dataclass class CheatMiniWoBLLMArgs_Retry(BaseModelArgs): diff --git a/tests/llm/test_llm_utils.py b/tests/llm/test_llm_utils.py index ae74bd02..1314bea0 100644 --- 
a/tests/llm/test_llm_utils.py +++ b/tests/llm/test_llm_utils.py @@ -124,46 +124,46 @@ def mock_rate_limit_error(message: str, status_code: Literal[429] = 429) -> Rate # Test to ensure function stops retrying after reaching the max wait time -def test_rate_limit_max_wait_time(): - mock_chat = MockChatOpenAI() - mock_chat.invoke = Mock( - side_effect=mock_rate_limit_error("Rate limit reached. Please try again in 2s.") - ) - - with pytest.raises(RateLimitError): - llm_utils.retry( - mock_chat, - [], - n_retry=4, - parser=mock_parser, - rate_limit_max_wait_time=6, - min_retry_wait_time=1, - ) - - # The function should stop retrying after 2 attempts (6s each time, 12s total which is greater than the 10s max wait time) - assert mock_chat.invoke.call_count == 3 - - -def test_rate_limit_success(): - mock_chat = MockChatOpenAI() - mock_chat.invoke = Mock( - side_effect=[ - mock_rate_limit_error("Rate limit reached. Please try again in 2s."), - make_system_message("correct content"), - ] - ) - - result = llm_utils.retry( - mock_chat, - [], - n_retry=4, - parser=mock_parser, - rate_limit_max_wait_time=6, - min_retry_wait_time=1, - ) - - assert result == "Parsed value" - assert mock_chat.invoke.call_count == 2 +# def test_rate_limit_max_wait_time(): +# mock_chat = MockChatOpenAI() +# mock_chat.invoke = Mock( +# side_effect=mock_rate_limit_error("Rate limit reached. Please try again in 2s.") +# ) + +# with pytest.raises(RateLimitError): +# llm_utils.retry( +# mock_chat, +# [], +# n_retry=4, +# parser=mock_parser, +# rate_limit_max_wait_time=6, +# min_retry_wait_time=1, +# ) + +# # The function should stop retrying after 2 attempts (6s each time, 12s total which is greater than the 10s max wait time) +# assert mock_chat.invoke.call_count == 3 + + +# def test_rate_limit_success(): +# mock_chat = MockChatOpenAI() +# mock_chat.invoke = Mock( +# side_effect=[ +# mock_rate_limit_error("Rate limit reached. 
Please try again in 2s."), +# make_system_message("correct content"), +# ] +# ) + +# result = llm_utils.retry( +# mock_chat, +# [], +# n_retry=4, +# parser=mock_parser, +# rate_limit_max_wait_time=6, +# min_retry_wait_time=1, +# ) + +# assert result == "Parsed value" +# assert mock_chat.invoke.call_count == 2 # Mock a successful parser response to test function exit before max retries @@ -180,7 +180,7 @@ def test_successful_parse_before_max_retries(): ] ) - result = llm_utils.retry(mock_chat, [], 5, mock_parser, min_retry_wait_time=1) + result = llm_utils.retry(mock_chat, [], 5, mock_parser) assert result == "Parsed value" assert mock_chat.invoke.call_count == 3 @@ -198,7 +198,7 @@ def test_unsuccessful_parse_before_max_retries(): make_system_message("correct content"), ] ) - with pytest.raises(ValueError): + with pytest.raises(llm_utils.ParseError): result = llm_utils.retry(mock_chat, [], 2, mock_parser) assert mock_chat.invoke.call_count == 2 From 1bf39433ac588019028e8c142789939deee52341 Mon Sep 17 00:00:00 2001 From: ThibaultLSDC Date: Fri, 27 Sep 2024 17:09:23 -0400 Subject: [PATCH 34/37] fix error handling --- src/agentlab/llm/chat_api.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/agentlab/llm/chat_api.py b/src/agentlab/llm/chat_api.py index e910cd08..6841f1ba 100644 --- a/src/agentlab/llm/chat_api.py +++ b/src/agentlab/llm/chat_api.py @@ -289,7 +289,7 @@ def __call__(self, messages: list[dict]) -> dict: if not completion: raise RetryError( f"Failed to get a response from the API after {self.max_retry} retries\n" - f"Last error: {e}" + f"Last error: {error_type}" ) input_tokens = completion.usage.prompt_tokens From b00a15dc4f9630299345e306535f13eed1954465 Mon Sep 17 00:00:00 2001 From: ThibaultLSDC Date: Mon, 30 Sep 2024 10:18:48 -0400 Subject: [PATCH 35/37] updating hf llm class --- src/agentlab/llm/huggingface_utils.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/src/agentlab/llm/huggingface_utils.py b/src/agentlab/llm/huggingface_utils.py index 9672c738..ce4dae06 100644 --- a/src/agentlab/llm/huggingface_utils.py +++ b/src/agentlab/llm/huggingface_utils.py @@ -97,6 +97,9 @@ def _llm_type(self): def invoke(self, messages: list[dict]) -> dict: return self(messages) + def get_stats(self): + return {} + class HuggingFaceURLChatModel(HFBaseChatModel): def __init__( From 80308a0403c6954a4bcf8fed1cd44c3517d1d674 Mon Sep 17 00:00:00 2001 From: ThibaultLSDC Date: Tue, 1 Oct 2024 10:57:24 -0400 Subject: [PATCH 36/37] testing Retry/Parse Error behaviors --- src/agentlab/llm/chat_api.py | 2 +- tests/agents/test_agent.py | 116 ++++++++++++++++++++++++++++++++--- 2 files changed, 107 insertions(+), 11 deletions(-) diff --git a/src/agentlab/llm/chat_api.py b/src/agentlab/llm/chat_api.py index 6841f1ba..87076067 100644 --- a/src/agentlab/llm/chat_api.py +++ b/src/agentlab/llm/chat_api.py @@ -309,7 +309,7 @@ def invoke(self, messages: list[dict]) -> dict: def get_stats(self): return { "n_retry_llm": self.retries, - "busted_retry_llm": int(not self.success), + # "busted_retry_llm": int(not self.success), # not logged if it occurs anyways } diff --git a/tests/agents/test_agent.py b/tests/agents/test_agent.py index dfb7923f..e3de8d01 100644 --- a/tests/agents/test_agent.py +++ b/tests/agents/test_agent.py @@ -4,6 +4,7 @@ from pathlib import Path from browsergym.experiments.loop import EnvArgs, ExpArgs +from openai import OpenAIError from agentlab.agents.generic_agent.agent_configs import FLAGS_GPT_3_5 from agentlab.agents.generic_agent.generic_agent import 
GenericAgentArgs @@ -43,7 +44,7 @@ def test_generic_agent(): @dataclass -class CheatMiniWoBLLM_Retry: +class CheatMiniWoBLLM_ParseRetry: """For unit-testing purposes only. It only work with miniwob.click-test task.""" n_retry: int @@ -74,22 +75,68 @@ def __call__(self, messages) -> str: return self.invoke(messages) def get_stats(self): - return {"n_retry": self.n_retry, "busted_retry": self.retry_count} + return {} @dataclass -class CheatMiniWoBLLMArgs_Retry(BaseModelArgs): +class CheatMiniWoBLLMArgs_ParseRetry(BaseModelArgs): n_retry: int = 2 - model_name: str = "test/cheat_miniwob_click_test_retry" + model_name: str = "test/cheat_miniwob_click_test_parse_retry" def make_model(self): - return CheatMiniWoBLLM_Retry(n_retry=self.n_retry) + return CheatMiniWoBLLM_ParseRetry(n_retry=self.n_retry) -def test_generic_agent_retry(): +@dataclass +class CheatLLM_LLMError: + """For unit-testing purposes only. Fails to call LLM""" + + n_retry: int = 0 + success: bool = False + + def invoke(self, messages) -> str: + if self.success: + prompt = messages[1].get("content", "") + match = re.search(r"^\s*\[(\d+)\].*button", prompt, re.MULTILINE | re.IGNORECASE) + + if match: + bid = match.group(1) + action = f'click("{bid}")' + else: + raise Exception("Can't find the button's bid") + + answer = f"""I'm clicking the button as requested. + + {action} + + """ + return dict(role="assistant", content=answer) + raise OpenAIError("LLM failed to respond") + + def __call__(self, messages) -> str: + return self.invoke(messages) + + def get_stats(self): + return {"n_llm_retry": self.n_retry, "n_llm_busted_retry": int(not self.success)} + + +@dataclass +class CheatLLMArgs_LLMError(BaseModelArgs): + n_retry: int = 2 + success: bool = False + model_name: str = "test/cheat_miniwob_click_test_parse_retry" + + def make_model(self): + return CheatLLM_LLMError( + n_retry=self.n_retry, + success=self.success, + ) + + +def test_generic_agent_parse_retry(): exp_args = ExpArgs( agent_args=GenericAgentArgs( - chat_model_args=CheatMiniWoBLLMArgs_Retry(n_retry=2), + chat_model_args=CheatMiniWoBLLMArgs_ParseRetry(n_retry=2), flags=FLAGS_GPT_3_5, ), env_args=EnvArgs(task_name="miniwob.click-test", task_seed=42), @@ -111,10 +158,10 @@ def test_generic_agent_retry(): assert result_record[key].iloc[0] == target_val -def test_bust_retry(): +def test_bust_parse_retry(): exp_args = ExpArgs( agent_args=GenericAgentArgs( - chat_model_args=CheatMiniWoBLLMArgs_Retry(n_retry=10), + chat_model_args=CheatMiniWoBLLMArgs_ParseRetry(n_retry=10), flags=FLAGS_GPT_3_5, ), env_args=EnvArgs(task_name="miniwob.click-test", task_seed=42), @@ -129,6 +176,31 @@ def test_bust_retry(): "stats.cum_busted_retry": 1, "n_steps": 0, "cum_reward": 0, + "err_msg": None, # parsing error is considered an agent failure, not a code error + } + + for key, target_val in target.items(): + assert key in result_record + assert result_record[key].iloc[0] == target_val + + +def test_llm_error_success(): + exp_args = ExpArgs( + agent_args=GenericAgentArgs( + chat_model_args=CheatLLMArgs_LLMError(n_retry=3, success=True), + flags=FLAGS_GPT_3_5, + ), + env_args=EnvArgs(task_name="miniwob.click-test", task_seed=42), + ) + + with tempfile.TemporaryDirectory() as tmp_dir: + launch_exp.run_experiments(1, [exp_args], Path(tmp_dir) / "generic_agent_test") + result_record = inspect_results.load_result_df(tmp_dir, progress_fn=None) + + target = { + "stats.cum_n_llm_retry": 3, + "n_steps": 1, + "cum_reward": 1.0, "err_msg": None, } @@ -137,6 +209,30 @@ def test_bust_retry(): assert 
result_record[key].iloc[0] == target_val +def test_llm_error_no_success(): + exp_args = ExpArgs( + agent_args=GenericAgentArgs( + chat_model_args=CheatLLMArgs_LLMError(n_retry=5, success=False), + flags=FLAGS_GPT_3_5, + ), + env_args=EnvArgs(task_name="miniwob.click-test", task_seed=42), + ) + + with tempfile.TemporaryDirectory() as tmp_dir: + launch_exp.run_experiments(1, [exp_args], Path(tmp_dir) / "generic_agent_test") + result_record = inspect_results.load_result_df(tmp_dir, progress_fn=None) + + target = { + "n_steps": 0, + "cum_reward": 0, + "err_msg": "Exception uncaught by agent or environment in task miniwob.click-test.\nOpenAIError:\nLLM failed to respond", + } + + for key, target_val in target.items(): + assert key in result_record + assert result_record[key].iloc[0] == target_val + + if __name__ == "__main__": # test_generic_agent() - test_bust_retry() + test_llm_error_no_success() From b8a151771e762c1e7158732a0a25db8454836195 Mon Sep 17 00:00:00 2001 From: ThibaultLSDC Date: Wed, 2 Oct 2024 10:40:24 -0400 Subject: [PATCH 37/37] moving functions around --- .../generic_agent/reproducibility_agent.py | 34 +++++-------------- src/agentlab/llm/llm_utils.py | 17 +++++++++- 2 files changed, 25 insertions(+), 26 deletions(-) diff --git a/src/agentlab/agents/generic_agent/reproducibility_agent.py b/src/agentlab/agents/generic_agent/reproducibility_agent.py index 6c778b44..2d07bf44 100644 --- a/src/agentlab/agents/generic_agent/reproducibility_agent.py +++ b/src/agentlab/agents/generic_agent/reproducibility_agent.py @@ -10,22 +10,22 @@ answers. Load the this reproducibility study in agent-xray to compare the results. """ +import difflib +import logging +import time from copy import copy from dataclasses import dataclass -import logging from pathlib import Path -import time +from browsergym.experiments.agent import AgentInfo +from browsergym.experiments.loop import ExpArgs, ExpResult, yield_all_exp_results from bs4 import BeautifulSoup from agentlab.agents.agent_args import AgentArgs -from .generic_agent import GenericAgentArgs, GenericAgent -from browsergym.experiments.loop import ExpResult, ExpArgs, yield_all_exp_results -from browsergym.experiments.agent import AgentInfo -import difflib +from agentlab.llm.chat_api import make_assistant_message +from agentlab.llm.llm_utils import messages_to_dict -from langchain.schema import BaseMessage, AIMessage -from langchain_community.adapters.openai import convert_message_to_dict +from .generic_agent import GenericAgent, GenericAgentArgs class ReproChatModel: @@ -45,8 +45,7 @@ def invoke(self, messages: list): if len(messages) >= len(self.old_messages): # if for some reason the llm response was not saved - # TODO(thibault): convert this to dict instead of AIMessage in the bye langchain PR. - return AIMessage(content="""None""") + return make_assistant_message("""None""") old_response = self.old_messages[len(messages)] self.new_messages.append(old_response) @@ -108,21 +107,6 @@ def get_action(self, obs): ) -# TODO(thibault): move this to llm utils in bye langchain PR. 
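# messages_to_dict normalizes a mixed chat history into OpenAI-style dicts:
# plain dicts pass through unchanged, bare strings become
# {"role": "", "content": s}, and langchain BaseMessage objects are converted
# with convert_message_to_dict; any other type raises ValueError.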
-def messages_to_dict(messages: list[dict] | list[BaseMessage]) -> dict:
-    new_messages = []
-    for m in messages:
-        if isinstance(m, dict):
-            new_messages.append(m)
-        elif isinstance(m, str):
-            new_messages.append({"role": "", "content": m})
-        elif isinstance(m, BaseMessage):
-            new_messages.append(convert_message_to_dict(m))
-        else:
-            raise ValueError(f"Unknown message type: {type(m)}")
-    return new_messages
-
-
 def _make_agent_stats(action, agent_info, step_info, old_chat_messages, new_chat_messages):
     if isinstance(agent_info, dict):
         agent_info = AgentInfo(**agent_info)
diff --git a/src/agentlab/llm/llm_utils.py b/src/agentlab/llm/llm_utils.py
index afbcb27d..4b876b54 100644
--- a/src/agentlab/llm/llm_utils.py
+++ b/src/agentlab/llm/llm_utils.py
@@ -13,7 +13,8 @@
 import numpy as np
 import tiktoken
 import yaml
-from openai import BadRequestError, RateLimitError
+from langchain.schema import BaseMessage
+from langchain_community.adapters.openai import convert_message_to_dict
 from PIL import Image
 from transformers import AutoModel, AutoTokenizer

@@ -21,14 +22,20 @@
 from agentlab.llm.chat_api import ChatModel


+def messages_to_dict(messages: list[dict] | list[BaseMessage]) -> list[dict]:
+    new_messages = []
+    for m in messages:
+        if isinstance(m, dict):
+            new_messages.append(m)
+        elif isinstance(m, str):
+            new_messages.append({"role": "", "content": m})
+        elif isinstance(m, BaseMessage):
+            new_messages.append(convert_message_to_dict(m))
+        else:
+            raise ValueError(f"Unknown message type: {type(m)}")
+    return new_messages
+
+
 class RetryError(ValueError):
     pass
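
Taken together, these patches leave a small exception-driven control flow: ChatModel absorbs API-level retries and reports them through get_stats(), while llm_utils.retry() loops only on ParseError and feeds the parser's complaint back to the model as a user turn. The following is a minimal, self-contained sketch of that contract; StubChatModel and the canned replies are hypothetical stand-ins for the real OpenAI-backed classes, not code from the patches above:

    class ParseError(ValueError):
        """Stand-in for agentlab.llm.llm_utils.ParseError."""


    class StubChatModel:
        """Returns the next canned reply; counts calls the way ChatModel does."""

        def __init__(self, replies):
            self.replies = iter(replies)
            self.retries = 0

        def invoke(self, messages):
            self.retries += 1
            return {"role": "assistant", "content": next(self.replies)}

        def get_stats(self):
            return {"n_retry_llm": self.retries}


    def retry(chat, messages, n_retry, parser):
        # Same shape as the post-patch llm_utils.retry: parse failures become
        # corrective user turns; API-level retries live inside the chat model.
        for _ in range(n_retry):
            answer = chat.invoke(messages)
            messages.append(answer)
            try:
                return parser(answer["content"])
            except ParseError as err:
                messages.append({"role": "user", "content": str(err)})
        raise ParseError(f"Could not parse a valid value after {n_retry} retries.")


    def parser(content):
        if "click(" not in content:
            raise ParseError("Reply with a single click(...) action.")
        return {"action": content.strip()}


    chat = StubChatModel(["I would press the button.", 'click("42")'])
    result = retry(chat, [{"role": "user", "content": "Click the button."}], 3, parser)
    assert result == {"action": 'click("42")'}
    assert chat.get_stats() == {"n_retry_llm": 2}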