Skip to content

Commit

Permalink
feat(models): add ActionEvent.prompt_for_description (#933)
Browse files Browse the repository at this point in the history
* add ActionEvent.prompt_for_description

* add display_event(darken_outside, display_text)

* add experiments/describe_action.py

* default RECORD_AUDIO to false

* use joinedload in get_latest_recording

* set anthropic.py MODEL_NAME to claude-3-5-sonnet-20241022

* support PNG in utils.image2utf8

* python>=3.10,<3.12
  • Loading branch information
abrichr authored Jan 2, 2025
1 parent e595dd3 commit 266b9bf
Show file tree
Hide file tree
Showing 9 changed files with 288 additions and 51 deletions.
116 changes: 116 additions & 0 deletions experiments/describe_action.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,116 @@
"""Generate action descriptions."""

from pprint import pformat

from loguru import logger
import cv2
import numpy as np

from openadapt.db import crud


def embed_description(
    image: np.ndarray,
    description: str,
    x: int | None = None,
    y: int | None = None,
) -> np.ndarray:
    """Embed a description into an image at the specified location.

    The text is greedily word-wrapped to a fixed character width and each
    line is drawn over a filled black rectangle so it stays legible on any
    background.  NOTE: cv2 draws directly into ``image``, so the input
    array is modified in place and also returned for convenience.

    Args:
        image (np.ndarray): The image to annotate (modified in place).
        description (str): The text to embed.
        x (int | None): The x-coordinate of the text center.
            Defaults to None (horizontal center of the image).
        y (int | None): The y-coordinate of the first text line.
            Defaults to None (vertical center of the image).

    Returns:
        np.ndarray: The annotated image (the same object as ``image``).
    """
    font = cv2.FONT_HERSHEY_SIMPLEX
    font_scale = 1
    font_color = (255, 255, 255)  # White
    line_type = 1

    # Greedily pack words into lines of at most max_width characters.
    max_width = 60  # Maximum characters per line
    words = description.split()
    lines = []
    current_line = []
    for word in words:
        if len(" ".join(current_line + [word])) <= max_width:
            current_line.append(word)
        else:
            # Guard against emitting an empty line when the very first
            # word is longer than max_width (current_line still empty).
            if current_line:
                lines.append(" ".join(current_line))
            current_line = [word]
    if current_line:
        lines.append(" ".join(current_line))

    # Default to the image center if coordinates are not provided.
    if x is None or y is None:
        x = image.shape[1] // 2
        y = image.shape[0] // 2

    # Draw a semi-opaque background box and the text for each line.
    for i, line in enumerate(lines):
        text_size, _ = cv2.getTextSize(line, font, font_scale, line_type)
        # Clamp horizontally so the text never runs off the image edges.
        text_x = max(0, min(x - text_size[0] // 2, image.shape[1] - text_size[0]))
        text_y = y + i * 20  # fixed 20px line spacing

        # Draw background
        cv2.rectangle(
            image,
            (text_x - 15, text_y - 25),
            (text_x + text_size[0] + 15, text_y + 15),
            (0, 0, 0),
            -1,
        )

        # Draw text
        cv2.putText(
            image,
            line,
            (text_x, text_y),
            font,
            font_scale,
            font_color,
            line_type,
        )

    return image


def main() -> None:
    """Describe and display every action event of the latest recording.

    Loads the most recent recording, asks the model (via
    ``prompt_for_description``) to describe each processed action event,
    overlays the description on the event's screenshot — anchored at the
    mouse position for mouse events, centered otherwise — shows each
    annotated frame in an OpenCV window, and logs the collected
    descriptions.
    """
    with crud.get_new_session(read_only=True) as session:
        recording = crud.get_latest_recording(session)
        descriptions = []
        for action in recording.processed_action_events:
            description, image = action.prompt_for_description(return_image=True)

            # OpenCV expects a numpy array rather than a PIL image.
            image = np.array(image)

            # Anchor at the mouse position when available; the *2 factor
            # presumably compensates for retina/HiDPI scaling — TODO confirm.
            kwargs = {}
            if action.mouse_x is not None and action.mouse_y is not None:
                kwargs = {
                    "x": int(action.mouse_x) * 2,
                    "y": int(action.mouse_y) * 2,
                }
            annotated_image = embed_description(image, description, **kwargs)

            logger.info(f"{action=}")
            logger.info(f"{description=}")
            cv2.imshow("Annotated Image", annotated_image)
            cv2.waitKey(0)  # wait for a keypress before the next event
            descriptions.append(description)

        logger.info(f"descriptions=\n{pformat(descriptions)}")


if __name__ == "__main__":
    main()
2 changes: 1 addition & 1 deletion openadapt/config.defaults.json
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
"RECORD_READ_ACTIVE_ELEMENT_STATE": false,
"REPLAY_STRIP_ELEMENT_STATE": true,
"RECORD_VIDEO": true,
"RECORD_AUDIO": true,
"RECORD_AUDIO": false,
"RECORD_BROWSER_EVENTS": false,
"RECORD_FULL_VIDEO": false,
"RECORD_IMAGES": false,
Expand Down
20 changes: 11 additions & 9 deletions openadapt/db/crud.py
Original file line number Diff line number Diff line change
Expand Up @@ -337,16 +337,18 @@ def get_all_scrubbed_recordings(


def get_latest_recording(session: SaSession) -> Recording:
"""Get the latest recording.
Args:
session (sa.orm.Session): The database session.
Returns:
Recording: The latest recording object.
"""
"""Get the latest recording with preloaded relationships."""
return (
session.query(Recording).order_by(sa.desc(Recording.timestamp)).limit(1).first()
session.query(Recording)
.options(
sa.orm.joinedload(Recording.screenshots),
sa.orm.joinedload(Recording.action_events)
.joinedload(ActionEvent.screenshot)
.joinedload(Screenshot.recording),
sa.orm.joinedload(Recording.window_events),
)
.order_by(sa.desc(Recording.timestamp))
.first()
)


Expand Down
10 changes: 6 additions & 4 deletions openadapt/drivers/anthropic.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,14 +5,14 @@
from PIL import Image
import anthropic

from openadapt import cache, utils
from openadapt import cache
from openadapt.config import config
from openadapt.custom_logger import logger

MAX_TOKENS = 4096
# from https://docs.anthropic.com/claude/docs/vision
MAX_IMAGES = 20
MODEL_NAME = "claude-3-opus-20240229"
MODEL_NAME = "claude-3-5-sonnet-20241022"


@cache.cache()
Expand All @@ -24,6 +24,8 @@ def create_payload(
max_tokens: int | None = None,
) -> dict:
"""Creates the payload for the Anthropic API request with image support."""
from openadapt import utils

messages = []

user_message_content = []
Expand All @@ -36,7 +38,7 @@ def create_payload(
# Add base64 encoded images to the user message content
if images:
for image in images:
image_base64 = utils.image2utf8(image)
image_base64 = utils.image2utf8(image, "PNG")
# Extract media type and base64 data
# TODO: don't add it to begin with
media_type, image_base64_data = image_base64.split(";base64,", 1)
Expand Down Expand Up @@ -90,7 +92,7 @@ def get_completion(
"""Sends a request to the Anthropic API and returns the response."""
client = anthropic.Anthropic(api_key=api_key)
try:
response = client.messages.create(**payload)
response = client.beta.messages.create(**payload)
except Exception as exc:
logger.exception(exc)
if dev_mode:
Expand Down
80 changes: 80 additions & 0 deletions openadapt/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import copy
import io
import sys
import textwrap

from bs4 import BeautifulSoup
from pynput import keyboard
Expand All @@ -16,6 +17,7 @@

from openadapt.config import config
from openadapt.custom_logger import logger
from openadapt.drivers import anthropic
from openadapt.db import db
from openadapt.privacy.base import ScrubbingProvider, TextScrubbingMixin
from openadapt.privacy.providers import ScrubProvider
Expand Down Expand Up @@ -110,6 +112,9 @@ def processed_action_events(self) -> list:
if not self._processed_action_events:
session = crud.get_new_session(read_only=True)
self._processed_action_events = events.get_events(session, self)
# Preload screenshots to avoid lazy loading later
for event in self._processed_action_events:
event.screenshot
return self._processed_action_events

def scrub(self, scrubber: ScrubbingProvider) -> None:
Expand All @@ -125,6 +130,7 @@ class ActionEvent(db.Base):
"""Class representing an action event in the database."""

__tablename__ = "action_event"
_repr_ignore_attrs = ["reducer_names"]

_segment_description_separator = ";"

Expand Down Expand Up @@ -333,6 +339,11 @@ def canonical_text(self, value: str) -> None:
if not value == self.canonical_text:
logger.warning(f"{value=} did not match {self.canonical_text=}")

@property
def raw_text(self) -> str:
    """The action text with all ``config.ACTION_TEXT_SEP`` separators removed."""
    segments = self.text.split(config.ACTION_TEXT_SEP)
    return "".join(segments)

def __str__(self) -> str:
"""Return a string representation of the action event."""
attr_names = [
Expand Down Expand Up @@ -544,6 +555,75 @@ def next_event(self) -> Union["ActionEvent", None]:

return None

def prompt_for_description(self, return_image: bool = False) -> str | tuple:
    """Use the Anthropic API to describe what is happening in the action event.

    Renders the event via ``display_event`` with the action location
    highlighted (everything outside the marker darkened), then:

    * for keyboard events (``self.text`` is set), builds the description
      directly from the typed text with no API call;
    * otherwise asks the model to name the UI element under the marker and
      phrases the description from the event name (move / scroll / *click).

    Args:
        return_image (bool): Whether to also return the rendered image that
            was (or, for keyboard events, would have been) sent to the model.

    Returns:
        str: The description of the action event; when ``return_image`` is
            True, a ``(description, image)`` tuple instead.

    Raises:
        ValueError: If ``self.name`` is not a handled event kind.
    """
    # Local import — presumably avoids a circular import between models
    # and plotting at module load time; TODO confirm.
    from openadapt.plotting import display_event

    image = display_event(
        self,
        marker_width_pct=0.05,
        marker_height_pct=0.05,
        darken_outside=0.7,
        display_text=False,
        marker_fill_transparency=0,
    )

    if self.text:
        # Keyboard input: describe the raw typed text directly.
        description = f"Type '{self.raw_text}'"
    else:
        prompt = (
            "What user interface element is contained in the highlighted circle "
            "of the image?"
        )
        # TODO: disambiguate
        system_prompt = textwrap.dedent(
            """
            Briefly describe the user interface element in the screenshot at the
            highlighted location.
            For example:
            - "OK button"
            - "URL bar"
            - "Down arrow"
            DO NOT DESCRIBE ANYTHING OUTSIDE THE HIGHLIGHTED AREA.
            Do not append anything like "is contained within the highlighted circle
            in the calculator interface." Just name the user interface element.
            """
        )

        logger.info(f"system_prompt=\n{system_prompt}")
        logger.info(f"prompt=\n{prompt}")

        # Call the Anthropic API
        element = anthropic.prompt(
            prompt=prompt,
            system_prompt=system_prompt,
            images=[image],
        )

        # Phrase the description from the event kind and the model's answer.
        if self.name == "move":
            description = f"Move mouse to '{element}'"
        elif self.name == "scroll":
            # TODO: "scroll to", dx/dy
            description = f"Scroll mouse on '{element}'"
        elif "click" in self.name:
            description = (
                f"{self.mouse_button_name.capitalize()} {self.name} '{element}'"
            )
        else:
            raise ValueError(f"Unhandled {self.name=} {self}")

    if return_image:
        return description, image
    else:
        return description


class WindowEvent(db.Base):
"""Class representing a window event in the database."""
Expand Down
Loading

0 comments on commit 266b9bf

Please sign in to comment.