Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account-related emails.

Already on GitHub? Sign in to your account

469: agent profiler, waterfall chart, step summary table #587

24 changes: 23 additions & 1 deletion lavague-core/lavague/core/agents.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,13 @@
from lavague.core.token_counter import TokenCounter
from lavague.core.utilities.config import is_flag_true

from lavague.core.utilities.profiling import (
ChartGenerator,
time_profiler,
start_new_step,
clear_profiling_data,
)

logging_print = logging.getLogger(__name__)
logging_print.setLevel(logging.INFO)
format = logging.Formatter("%(asctime)s - %(levelname)s - %(message)s")
Expand Down Expand Up @@ -503,7 +510,9 @@ def run(

try:
for _ in range(self.n_steps):
result = self.run_step(objective)
start_new_step()
with time_profiler("Run step", full_step_profiling=True):
result = self.run_step(objective)

if result is not None:
break
Expand Down Expand Up @@ -598,3 +607,16 @@ def display_all_nodes(self) -> None:

def set_origin(self, origin: str) -> None:
    """Store `origin` on this instance as its `origin` attribute."""
    self.origin = origin

def get_summary(self):
    """
    Build the profiling report and then reset the recorded data.

    Returns:
    - A `(plot, table)` tuple: the waterfall chart from
      `ChartGenerator.plot_waterfall()` and the summary from
      `ChartGenerator.get_summary_df()`.
    """
    # Imported at call time so the names resolve to the lists currently bound
    # in the profiling module.
    from lavague.core.utilities.profiling import agent_events, agent_steps

    generator = ChartGenerator(agent_events=agent_events, agent_steps=agent_steps)
    summary = (generator.plot_waterfall(), generator.get_summary_df())

    # Only reached when both artefacts were produced successfully.
    clear_profiling_data()

    return summary
20 changes: 15 additions & 5 deletions lavague-core/lavague/core/navigation.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
from PIL import Image
from llama_index.core.base.llms.base import BaseLLM
from llama_index.core.embeddings import BaseEmbedding
from lavague.core.utilities.profiling import time_profiler

NAVIGATION_ENGINE_PROMPT_TEMPLATE = ActionTemplate(
"""
Expand Down Expand Up @@ -154,9 +155,13 @@ def get_nodes(self, query: str) -> List[str]:
`List[str]`: The nodes
"""
viewport_only = not self.driver.previously_scanned
source_nodes = self.retriever.retrieve(
QueryBundle(query_str=query), [self.driver.get_html()], viewport_only
)

html = self.driver.get_html()

with time_profiler("Retriever Inference", html_size=len(html)):
source_nodes = self.retriever.retrieve(
QueryBundle(query_str=query), [html], viewport_only
)
return source_nodes

def add_knowledge(self, knowledge: str):
Expand Down Expand Up @@ -452,7 +457,10 @@ def execute_instruction(self, instruction: str) -> ActionResult:
query_str=instruction,
authorized_xpaths=authorized_xpaths,
)
response = self.llm.complete(prompt).text

with time_profiler("Navigation Engine Inference", prompt_size=len(prompt)):
response = self.llm.complete(prompt).text

end = time.time()
action_generation_time = end - start
action_outcome = {
Expand All @@ -476,7 +484,9 @@ def execute_instruction(self, instruction: str) -> ActionResult:
for item in vision_data:
display_screenshot(item["screenshot"])
time.sleep(0.2)
self.driver.exec_code(action)

with time_profiler("Execute Code"):
self.driver.exec_code(action)
time.sleep(self.time_between_actions)
if self.display:
try:
Expand Down
200 changes: 200 additions & 0 deletions lavague-core/lavague/core/utilities/profiling.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,200 @@
import time
import pandas as pd
import matplotlib.pyplot as plt
from functools import wraps
import io
from IPython.display import Image
from itertools import cycle
from contextlib import contextmanager

# stores profiled events (LLM calls, retriever calls, code execution), one inner list per agent step
agent_events = []

# stores total runtime of each step
agent_steps = []


# call before each agent step to group events by steps
def start_new_step():
    """Open a new, empty event group for the step that is about to run."""
    # The module-level list is mutated in place, so no `global` statement is
    # required here.
    agent_events.append(list())


def clear_profiling_data():
    """
    Discard every recorded event and step timing.

    Both module-level names are rebound to brand-new empty lists; the
    previous list objects are left untouched for anyone still holding them.
    """
    global agent_events, agent_steps
    agent_events, agent_steps = [], []


@contextmanager
def time_profiler(
    event_name, prompt_size=None, html_size=None, full_step_profiling=False
):
    """
    A context manager to profile the execution time of code blocks.

    A record is stored when the block exits, whether it finished normally or
    raised. Each record holds the event name, the `time.perf_counter()` value
    taken on entry and the elapsed duration in seconds.

    Parameters:
    - event_name: The name of the event being profiled.
    - prompt_size: Optional size of the prompt; stored only when not None.
    - html_size: Optional size of the HTML; stored only when not None.
    - full_step_profiling: When True the record is appended to `agent_steps`
      (timing of a whole step); otherwise it is appended to the event list of
      the most recently opened step in `agent_events`.
    """
    start_time = time.perf_counter()
    try:
        yield
    finally:
        duration = time.perf_counter() - start_time

        # create profiling record; size metrics are included only when given
        record = {
            "event_name": event_name,
            "start_time": start_time,
            "duration": duration,
        }
        if prompt_size is not None:
            record["prompt_size"] = prompt_size
        if html_size is not None:
            record["html_size"] = html_size

        # append the record to the appropriate list
        if full_step_profiling:
            agent_steps.append(record)
        else:
            # No step may be open yet (start_new_step() never called, or the
            # data was just cleared). Open an implicit one so that indexing
            # [-1] cannot raise IndexError from inside `finally`, which would
            # also hide any exception raised by the profiled block.
            if not agent_events:
                agent_events.append([])
            agent_events[-1].append(record)


class ChartGenerator:
    """
    Builds a waterfall chart and a summary table from profiling records.

    Parameters:
    - agent_events: One entry per step; each entry is a list of event records
      (dicts with "event_name", "start_time", "duration" and optional extra
      metrics such as "prompt_size" or "html_size").
    - agent_steps: Records with the same keys timing each whole step.
    """

    def __init__(self, agent_events, agent_steps):
        self.agent_events = agent_events
        self.total_step_runtime = agent_steps
        self.step_color = "grey"
        self.event_color_scheme = [
            "#FFB3B3",  # Pastel Red
            "#ADD8E6",  # Pastel Blue
            "#B2D8B2",  # Pastel Green
            "#FFCC99",  # Pastel Orange
            "#D1B3FF",  # Pastel Purple
            "#FFB3DE",  # Pastel Pink
            "#B3FFFF",  # Pastel Cyan
            "#FFFFB3",  # Pastel Yellow
            "#FFB3FF",  # Pastel Magenta
            "#D2B48C",  # Pastel Brown
        ]

    def plot_waterfall(self):
        """
        Render steps and their events as a horizontal waterfall chart.

        Each step is drawn as a bar in `self.step_color` on its own row, and
        the events of that step are drawn on top of it, one colour per event
        name. An empty chart is produced when nothing was recorded.

        Returns:
        - An `Image` wrapping the PNG bytes of the rendered chart.
        """
        steps = self.total_step_runtime

        # Align the x-axis to 0 using the first step's start time. When no
        # step was recorded, fall back to the earliest event (or 0.0) instead
        # of failing on an empty list.
        if steps:
            base_start_time = steps[0]["start_time"]
        else:
            base_start_time = min(
                (
                    event["start_time"]
                    for step_events in self.agent_events
                    for event in step_events
                ),
                default=0.0,
            )

        # Work on an explicit figure so the chart does not depend on, or
        # disturb, whichever pyplot figure happens to be current.
        fig, ax = plt.subplots(figsize=(20, 8))
        try:
            color_cycle = cycle(self.event_color_scheme)
            event_colors = {}

            # Plot total step runtime (from run_step)
            for i, step in enumerate(steps):
                duration = step["duration"]
                start_time = step["start_time"] - base_start_time  # Normalize to 0

                ax.barh(i, duration, left=start_time, color=self.step_color)
                ax.text(
                    start_time + duration / 2,
                    i - 0.45,
                    f"{duration:.2f}s",
                    ha="center",
                    va="center",
                )

            # Plot each individual event on top of the step runtime
            for step_index, step_events in enumerate(self.agent_events):
                for event in step_events:
                    duration = event["duration"]
                    event_name = event["event_name"]
                    start_time = event["start_time"] - base_start_time  # Normalize to 0

                    # Each distinct event name keeps one colour for the chart
                    if event_name not in event_colors:
                        event_colors[event_name] = next(color_cycle)

                    ax.barh(
                        step_index,
                        duration,
                        left=start_time,
                        color=event_colors[event_name],
                        alpha=1,
                    )
                    ax.text(
                        start_time + duration / 2,
                        step_index,
                        f"{duration:.2f}s",
                        ha="center",
                        va="center",
                        fontsize=9,
                        color="black",
                        rotation=90,
                    )

            # Label every row that can hold a bar, including event groups
            # that have no matching step record.
            row_count = max(len(steps), len(self.agent_events))
            ax.invert_yaxis()
            ax.set_yticks(list(range(row_count)))
            ax.set_yticklabels([f"Step {i + 1}" for i in range(row_count)])
            ax.set_xlabel("Time (seconds)")
            ax.set_title("Agent Event Waterfall")

            # Legend: one entry per event name, plus "Step" in the step colour
            legend_handles = [
                plt.Line2D([0], [0], color=color, lw=4)
                for color in event_colors.values()
            ]
            legend_handles.append(
                plt.Line2D([0], [0], color=self.step_color, lw=4)
            )
            ax.legend(
                legend_handles, list(event_colors) + ["Step"], title="Event Name"
            )

            # Save to buffer
            buf = io.BytesIO()
            fig.savefig(buf, format="png")
        finally:
            # Always release the figure, even when plotting failed.
            plt.close(fig)

        return Image(buf.getvalue())

    def get_summary_df(self):
        """
        Aggregate the recorded events into one table row per step.

        For every event name in a step, each metric of its records (every key
        except "event_name", "start_time" and "end_time") is summed into a
        "<event name> <metric>" column, and the number of occurrences is
        stored in a "<event name> count" column.

        Returns:
        - A pandas DataFrame indexed by "Step 1", "Step 2", ...
        """
        non_metric_keys = ("event_name", "start_time", "end_time")
        summary_data = {}

        for step_number, step_events in enumerate(self.agent_events, start=1):
            step_summary = {}
            event_counts = {}

            for event in step_events:
                event_name = event["event_name"]
                event_counts[event_name] = event_counts.get(event_name, 0) + 1

                # Sum every metric of this event into the step's totals
                for key, value in event.items():
                    if key in non_metric_keys:
                        continue
                    metric_key = f"{event_name} {key}"
                    if metric_key in step_summary:
                        step_summary[metric_key] += value
                    else:
                        step_summary[metric_key] = value

            # Counts are added after the metrics, keeping the metric columns first
            for event_name, count in event_counts.items():
                step_summary[f"{event_name} count"] = count

            summary_data[f"Step {step_number}"] = step_summary

        # Steps are the dict's outer keys, so transpose to get one row per step
        return pd.DataFrame(summary_data).T
8 changes: 7 additions & 1 deletion lavague-core/lavague/core/world_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from lavague.core.utilities.model_utils import get_model_name
import time
import yaml
from lavague.core.utilities.profiling import time_profiler

WORLD_MODEL_GENERAL_EXAMPLES = """
Objective: Go to the first issue you can find
Expand Down Expand Up @@ -430,7 +431,12 @@ def get_instruction(
)

start = time.time()
mm_llm_output = mm_llm.complete(prompt, image_documents=image_documents).text

with time_profiler("World Model Inference", prompt_size=len(prompt)):
mm_llm_output = mm_llm.complete(
prompt, image_documents=image_documents
).text

end = time.time()
world_model_inference_time = end - start

Expand Down
Loading