Cost tracking for openrouter and openai #31

Merged: 21 commits, Sep 20, 2024
Changes from all commits

Commits
65a8f78
updating dependencies for github workflows
ThibaultLSDC Sep 11, 2024
a5a2000
Merge branch 'main' of github.com:ServiceNow/AgentLab
ThibaultLSDC Sep 13, 2024
510b835
Merge branch 'main' of github.com:ServiceNow/AgentLab into tracking
ThibaultLSDC Sep 17, 2024
ab0f4ec
openrouter tracker poc
ThibaultLSDC Sep 18, 2024
ca50598
adding openai pricing request
ThibaultLSDC Sep 18, 2024
0cde22b
switching back to langchain community for openai pricing
ThibaultLSDC Sep 19, 2024
060e1e7
renaming launch_command.py to main.py
ThibaultLSDC Sep 19, 2024
7eafcf8
Merge branch 'main' into tracking
ThibaultLSDC Sep 19, 2024
600cfca
typo
ThibaultLSDC Sep 19, 2024
708bde5
tracking is thread safe and mostly tested
ThibaultLSDC Sep 20, 2024
7c20945
added pricy tests for ChatModels
ThibaultLSDC Sep 20, 2024
18a45e0
separating get_pricing function
ThibaultLSDC Sep 20, 2024
9d12cdf
updating function names
ThibaultLSDC Sep 20, 2024
8e5e5f9
updating function names
ThibaultLSDC Sep 20, 2024
17e8ff8
ciao retry_parallel
ThibaultLSDC Sep 20, 2024
d62357f
fixing case when the context isnt used
ThibaultLSDC Sep 20, 2024
e3808f9
moving ChatModels to chat_api
ThibaultLSDC Sep 20, 2024
cfe01a0
renaming get_action decorator
ThibaultLSDC Sep 20, 2024
5676f2e
Merge branch 'tracking' of github.com:ServiceNow/AgentLab into tracking
ThibaultLSDC Sep 20, 2024
37fe6b4
Merge branch 'main' of github.com:ServiceNow/AgentLab
ThibaultLSDC Sep 20, 2024
cd3069e
Merge branch 'main' into tracking
ThibaultLSDC Sep 20, 2024
17 changes: 8 additions & 9 deletions src/agentlab/experiments/launch_command.py → main.py
@@ -7,21 +7,22 @@

 import logging
 
-from agentlab.agents.generic_agent import RANDOM_SEARCH_AGENT, AGENT_4o, AGENT_4o_MINI
+from agentlab.agents.generic_agent import AGENT_CUSTOM, RANDOM_SEARCH_AGENT, AGENT_4o, AGENT_4o_MINI
 from agentlab.analyze.inspect_results import get_most_recent_folder
 from agentlab.experiments import study_generators
 from agentlab.experiments.exp_utils import RESULTS_DIR
-from agentlab.experiments.launch_exp import make_study_dir, run_experiments, relaunch_study
+from agentlab.experiments.launch_exp import make_study_dir, relaunch_study, run_experiments
 
 logging.getLogger().setLevel(logging.INFO)
 
 # choose your agent or provide a new agent
-agent_args = AGENT_4o_MINI
-# agent_args = AGENT_4o
+agent_args = [AGENT_4o_MINI]
+# agent = AGENT_4o
 
 
 ## select the benchmark to run on
-benchmark = "miniwob"
+benchmark = "miniwob_tiny_test"
+# benchmark = "miniwob"
 # benchmark = "workarena.l1"
 # benchmark = "workarena.l2"
 # benchmark = "workarena.l3"
@@ -45,8 +46,6 @@
 n_jobs = 1  # Make sure to use 1 job when debugging in VSCode
 # n_jobs = -1  # to use all available cores
 
-
-# Run the experiments
 # run the experiments
 if __name__ == "__main__":
 
-    run_experiments(n_jobs, exp_args_list, study_dir, parallel_backend="dask")
+    run_experiments(n_jobs, exp_args_list, study_dir)
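
Note that `agent_args` is now a list rather than a single config, so a study can presumably benchmark several agents side by side. A hypothetical variant, not part of this diff:

agent_args = [AGENT_4o_MINI, AGENT_4o]  # hypothetical: sweep two agents in one study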
2 changes: 2 additions & 0 deletions src/agentlab/agents/generic_agent/__init__.py
@@ -2,6 +2,7 @@
     AGENT_3_5,
     AGENT_8B,
     AGENT_70B,
+    AGENT_CUSTOM,
     RANDOM_SEARCH_AGENT,
     AGENT_4o,
     AGENT_4o_MINI,
@@ -16,4 +17,5 @@
"AGENT_70B",
"AGENT_8B",
"RANDOM_SEARCH_AGENT",
"AGENT_CUSTOM",
]
8 changes: 4 additions & 4 deletions src/agentlab/agents/generic_agent/agent_configs.py
@@ -1,9 +1,9 @@
-from .generic_agent_prompt import GenericPromptFlags
 from agentlab.agents import dynamic_prompting as dp
-from .generic_agent import GenericAgentArgs
-from agentlab.llm.llm_configs import CHAT_MODEL_ARGS_DICT
 from agentlab.experiments import args
+from agentlab.llm.llm_configs import CHAT_MODEL_ARGS_DICT
+
+from .generic_agent import GenericAgentArgs
+from .generic_agent_prompt import GenericPromptFlags
 
 FLAGS_CUSTOM = GenericPromptFlags(
     obs=dp.ObsFlags(
@@ -45,7 +45,7 @@


 AGENT_CUSTOM = GenericAgentArgs(
-    chat_model_args=CHAT_MODEL_ARGS_DICT["openai/gpt-3.5-turbo-1106"],
+    chat_model_args=CHAT_MODEL_ARGS_DICT["openrouter/meta-llama/llama-3.1-8b-instruct"],
     flags=FLAGS_CUSTOM,
 )

3 changes: 2 additions & 1 deletion src/agentlab/agents/generic_agent/generic_agent.py
@@ -10,6 +10,7 @@
 from agentlab.agents.utils import openai_monitored_agent
 from agentlab.llm.chat_api import BaseModelArgs
 from agentlab.llm.llm_utils import RetryError, retry_raise
+from agentlab.llm.tracking import cost_tracker_decorator
 
 from .generic_agent_prompt import GenericPromptFlags, MainPrompt
 
@@ -65,7 +66,7 @@ def __init__(
     def obs_preprocessor(self, obs: dict) -> dict:
         return self._obs_preprocessor(obs)
 
-    @openai_monitored_agent
+    @cost_tracker_decorator
     def get_action(self, obs):
 
         self.obs_history.append(obs)
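
The tracking module itself is not part of this diff, but its usage pins down the shape: `cost_tracker_decorator` wraps `get_action`, and `ChatModel.__call__` in chat_api.py (below) reports usage to `tracking.TRACKER.instance` whenever one is active. A minimal sketch of how these pieces could fit together, assuming `TRACKER` is thread-local state (per the "tracking is thread safe" commit) and `LLMTracker` is a context manager; every name and field here is inferred, not confirmed by the diff:

import threading
from functools import wraps


class _ThreadLocalTracker(threading.local):
    instance = None  # each thread sees None until it installs a tracker


TRACKER = _ThreadLocalTracker()


class LLMTracker:
    """Accumulates tokens and cost via calls like tracker(input_tokens, output_tokens, cost)."""

    def __init__(self):
        self.input_tokens = 0
        self.output_tokens = 0
        self.cost = 0.0

    def __call__(self, input_tokens, output_tokens, cost):
        self.input_tokens += input_tokens
        self.output_tokens += output_tokens
        self.cost += cost

    def __enter__(self):
        TRACKER.instance = self
        return self

    def __exit__(self, *exc):
        TRACKER.instance = None


def cost_tracker_decorator(get_action):
    """Run each get_action call under a fresh LLMTracker and attach the totals."""

    @wraps(get_action)
    def wrapper(self, obs):
        with LLMTracker() as tracker:
            action, agent_info = get_action(self, obs)
        # Attach the accumulated usage to the step's info; the real field name may differ.
        agent_info["stats"] = {
            "input_tokens": tracker.input_tokens,
            "output_tokens": tracker.output_tokens,
            "cost": tracker.cost,
        }
        return action, agent_info

    return wrapper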
2 changes: 2 additions & 0 deletions src/agentlab/agents/most_basic_agent/most_basic_agent.py
@@ -11,6 +11,7 @@

 from agentlab.llm.llm_configs import CHAT_MODEL_ARGS_DICT
 from agentlab.llm.llm_utils import ParseError, extract_code_blocks, retry_raise
+from agentlab.llm.tracking import cost_tracker_decorator
 
 if TYPE_CHECKING:
     from agentlab.llm.chat_api import BaseModelArgs
@@ -48,6 +49,7 @@ def __init__(

         self.action_set = HighLevelActionSet(["bid"], multiaction=False)
 
+    @cost_tracker_decorator
     def get_action(self, obs: Any) -> tuple[str, dict]:
         system_prompt = f"""
 You are a web assistant.
124 changes: 120 additions & 4 deletions src/agentlab/llm/chat_api.py
@@ -5,12 +5,13 @@
 from typing import TYPE_CHECKING
 
 from langchain.schema import AIMessage
-from langchain_openai import AzureChatOpenAI, ChatOpenAI
 
+import agentlab.llm.tracking as tracking
 from agentlab.llm.langchain_utils import (
     ChatOpenRouter,
     HuggingFaceAPIChatModel,
     HuggingFaceURLChatModel,
+    _convert_messages_to_dict,
 )
 
 if TYPE_CHECKING:
@@ -86,7 +87,7 @@ class OpenRouterModelArgs(BaseModelArgs):
model."""

def make_model(self):
return ChatOpenRouter(
return OpenRouterChatModel(
model_name=self.model_name,
temperature=self.temperature,
max_tokens=self.max_new_tokens,
@@ -99,7 +100,7 @@ class OpenAIModelArgs(BaseModelArgs):
model."""

def make_model(self):
return ChatOpenAI(
return OpenAIChatModel(
model_name=self.model_name,
temperature=self.temperature,
max_tokens=self.max_new_tokens,
@@ -126,7 +127,7 @@ class AzureModelArgs(BaseModelArgs):
     deployment_name: str = None
 
     def make_model(self):
-        return AzureChatOpenAI(
+        return AzureChatModel(
             model_name=self.model_name,
             temperature=self.temperature,
             max_tokens=self.max_new_tokens,
@@ -194,3 +195,118 @@ def __post_init__(self):

     def make_model(self):
         pass
+
+
+class ChatModel(ABC):
+
+    @abstractmethod
+    def __init__(self, model_name, api_key=None, temperature=0.5, max_tokens=100):
+        self.model_name = model_name
+        self.temperature = temperature
+        self.max_tokens = max_tokens
+
+        self.client = tracking.OpenAI()
+
+        self.input_cost = 0.0
+        self.output_cost = 0.0
+
+    def __call__(self, messages: list[dict]) -> dict:
+        messages_formatted = _convert_messages_to_dict(messages)
+        completion = self.client.chat.completions.create(
+            model=self.model_name,
+            messages=messages_formatted,
+            temperature=self.temperature,
+            max_tokens=self.max_tokens,
+        )
+        input_tokens = completion.usage.prompt_tokens
+        output_tokens = completion.usage.completion_tokens
+        cost = input_tokens * self.input_cost + output_tokens * self.output_cost
+
+        if isinstance(tracking.TRACKER.instance, tracking.LLMTracker):
+            tracking.TRACKER.instance(input_tokens, output_tokens, cost)
+
+        return AIMessage(content=completion.choices[0].message.content)
+
+    def invoke(self, messages: list[dict]) -> dict:
+        return self(messages)
+
+
+class OpenRouterChatModel(ChatModel):
+    def __init__(
+        self,
+        model_name,
+        api_key=None,
+        temperature=0.5,
+        max_tokens=100,
+    ):
+        self.model_name = model_name
+        self.temperature = temperature
+        self.max_tokens = max_tokens
+
+        api_key = api_key or os.getenv("OPENROUTER_API_KEY")
+
+        pricings = tracking.get_pricing_openrouter()
+
+        self.input_cost = pricings[model_name]["prompt"]
+        self.output_cost = pricings[model_name]["completion"]
+
+        self.client = tracking.OpenAI(
+            base_url="https://openrouter.ai/api/v1",
+            api_key=api_key,
+        )
+
+
+class OpenAIChatModel(ChatModel):
+    def __init__(
+        self,
+        model_name,
+        api_key=None,
+        temperature=0.5,
+        max_tokens=100,
+    ):
+        self.model_name = model_name
+        self.temperature = temperature
+        self.max_tokens = max_tokens
+
+        api_key = api_key or os.getenv("OPENAI_API_KEY")
+
+        pricings = tracking.get_pricing_openai()
+
+        self.input_cost = float(pricings[model_name]["prompt"])
+        self.output_cost = float(pricings[model_name]["completion"])
+
+        self.client = tracking.OpenAI(
+            api_key=api_key,
+        )
+
+
+class AzureChatModel(ChatModel):
+    def __init__(
+        self,
+        model_name,
+        api_key=None,
+        deployment_name=None,
+        temperature=0.5,
+        max_tokens=100,
+    ):
+        self.model_name = model_name
+        self.temperature = temperature
+        self.max_tokens = max_tokens
+
+        api_key = api_key or os.getenv("AZURE_OPENAI_API_KEY")
+
+        # AZURE_OPENAI_ENDPOINT has to be defined in the environment
+        endpoint = os.getenv("AZURE_OPENAI_ENDPOINT")
+        assert endpoint, "AZURE_OPENAI_ENDPOINT has to be defined in the environment"
+
+        pricings = tracking.get_pricing_openai()
+
+        self.input_cost = float(pricings[model_name]["prompt"])
+        self.output_cost = float(pricings[model_name]["completion"])
+
+        self.client = tracking.AzureOpenAI(
+            api_key=api_key,
+            azure_deployment=deployment_name,
+            azure_endpoint=endpoint,
+            api_version="2024-02-01",
+        )
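
End to end, a call site could look like the following sketch. This is hypothetical usage: the ChatModel classes come from the diff above, but the constructor arguments on OpenAIModelArgs are assumed from how make_model reads them.

from agentlab.llm.chat_api import OpenAIModelArgs

model_args = OpenAIModelArgs(model_name="gpt-4o-mini", temperature=0.5, max_new_tokens=100)
chat = model_args.make_model()  # an OpenAIChatModel, with per-token prices fetched at init

answer = chat([{"role": "user", "content": "Say hello."}])
print(answer.content)  # tokens and cost were reported to tracking.TRACKER.instance, if one is active

Since the tracker is resolved at call time from shared state rather than passed in, every request that goes through ChatModel.__call__ is counted without threading a tracker object through the agent code.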
63 changes: 0 additions & 63 deletions src/agentlab/llm/llm_utils.py
@@ -177,69 +177,6 @@ def retry_raise(
     raise RetryError(f"Could not parse a valid value after {n_retry} retries.")
 
 
-def retry_parallel(chat: "BaseChatModel", messages, n_retry, parser):
-    """Retry querying the chat models with the response from the parser until it returns a valid value.
-
-    It will stop after `n_retry`. It assuemes that chat will generate n_parallel answers for each message.
-    The best answer is selected according to the score returned by the parser. If no answer is valid, the
-    it will retry with the best answer so far and append to the chat the retry message. If there is a
-    single parallel generation, it behaves like retry.
-
-    This function is, in principle, more robust than retry. The speed and cost overhead is minimal with
-    the prompt is large and the length of the generated message is small.
-
-    Args:
-        chat (BaseChatModel): a langchain BaseChatModel taking a list of messages and
-            returning a list of answers.
-        messages (list): the list of messages so far.
-        n_retry (int): the maximum number of sequential retries.
-        parser (function): a function taking a message and returning a tuple
-            with the following fields:
-                value : the parsed value,
-                valid : a boolean indicating if the value is valid,
-                retry_message : a message to send to the chat if the value is not valid
-
-    Returns:
-        dict: the parsed value, with a string at key "action".
-
-    Raises:
-        ValueError: if the parser could not parse a valid value after n_retry retries.
-        BadRequestError: if the message is too long
-    """
-
-    for i in range(n_retry):
-        try:
-            answers = chat.generate([messages]).generations[0]  # chat.n parallel completions
-        except BadRequestError as e:
-            # most likely, the added messages triggered a message too long error
-            # we thus retry without the last two messages
-            if i == 0:
-                raise e
-            msg = f"BadRequestError, most likely the message is too long retrying with previous query."
-            warn(msg)
-            messages = messages[:-2]
-            answers = chat.generate([messages]).generations[0]
-
-        values, valids, retry_messages, scores = zip(
-            *[parser(answer.message.content) for answer in answers]
-        )
-        idx = np.argmax(scores)
-        value = values[idx]
-        valid = valids[idx]
-        retry_message = retry_messages[idx]
-        answer = answers[idx].message
-
-        if valid:
-            return value
-
-        msg = f"Query failed. Retrying {i+1}/{n_retry}.\n[LLM]:\n{answer.content}\n[User]:\n{retry_message}"
-        warn(msg)
-        messages.append(answer)  # already of type AIMessage
-        messages.append(SystemMessage(content=retry_message))
-
-    raise ValueError(f"Could not parse a valid value after {n_retry} retries.")
-
-
 def truncate_tokens(text, max_tokens=8000, start=0, model_name="gpt-4"):
     """Use tiktoken to truncate a text to a maximum number of tokens."""
     enc = tiktoken.encoding_for_model(model_name)
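
The commit log also mentions "pricy tests for ChatModels", i.e. tests that hit the live APIs and spend a few tokens. A hypothetical shape for such a test; the marker name and assertions are assumptions, not shown in this PR:

import pytest

from agentlab.llm.chat_api import OpenAIModelArgs


@pytest.mark.pricy  # assumed marker so paid tests can be deselected by default
def test_openai_chat_model_reports_cost():
    chat = OpenAIModelArgs(model_name="gpt-4o-mini", max_new_tokens=16).make_model()
    answer = chat([{"role": "user", "content": "Say hello."}])
    assert answer.content  # got a completion back
    assert chat.input_cost > 0 and chat.output_cost > 0  # pricing table was fetched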