Skip to content

Commit

Permalink
working
Browse files Browse the repository at this point in the history
  • Loading branch information
abrichr committed Apr 7, 2024
1 parent b7acb4d commit bd57862
Show file tree
Hide file tree
Showing 3 changed files with 102 additions and 47 deletions.
3 changes: 2 additions & 1 deletion openadapt/events.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
MOUSE_MOVE_EVENT_MERGE_DISTANCE_THRESHOLD = 1
MOUSE_MOVE_EVENT_MERGE_MIN_IDX_DELTA = 5
KEYBOARD_EVENTS_MERGE_GROUP_NAMED_KEYS = True
USE_SCREENSHOT_DIFFS = False


def get_events(
Expand Down Expand Up @@ -148,7 +149,7 @@ def make_parent_event(


def merge_consecutive_mouse_move_events(
events: list[models.ActionEvent], by_diff_distance: bool = False
events: list[models.ActionEvent], by_diff_distance: bool = USE_SCREENSHOT_DIFFS,
) -> list[models.ActionEvent]:
"""Merge consecutive mouse move events into a single move event.
Expand Down
40 changes: 23 additions & 17 deletions openadapt/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -270,7 +270,7 @@ def from_dict(cls, action_dict: dict) -> list['ActionEvent']:
# TODO: use config.ACTION_TEXT_SEP, ACTION_TEXT_NAME_PREFIX/SUFFIX
children = []
release_events = []
try:
if "text" in action_dict:
# Splitting actions based on whether they are special keys or regular characters
if action_dict['text'].startswith('<') and action_dict['text'].endswith('>'):
# Handling special keys
Expand All @@ -288,12 +288,10 @@ def from_dict(cls, action_dict: dict) -> list['ActionEvent']:
press, release = cls._create_key_events(key_char=key_char)
children.append(press)
children.append(release)
except KeyError as exc:
# Handle missing key names or canonical text appropriately
logger.warning(f"{exc=}")

children += release_events[::-1]
return ActionEvent(children=children)
children += release_events[::-1]
return ActionEvent(children=children)
else:
return ActionEvent(**action_dict)

@classmethod
def _create_key_events(
Expand Down Expand Up @@ -382,16 +380,24 @@ class Screenshot(db.Base):
recording = sa.orm.relationship("Recording", back_populates="screenshots")
action_event = sa.orm.relationship("ActionEvent", back_populates="screenshot")

# TODO: convert to png_data on save
sct_img = None

# TODO: replace prev with prev_timestamp?
prev = None
_image = None
_image_history = []
_diff = None
_diff_mask = None
_base64 = None
def __init__(self, *args, sct_img=None, **kwargs):
super().__init__(*args, **kwargs)
self.initialize_instance_attributes()
self.sct_img = sct_img

@sa.orm.reconstructor
def initialize_instance_attributes(self):
"""Initialize attributes for both new and loaded objects."""
# TODO: convert to png_data on save
self.sct_img = None

# TODO: replace prev with prev_timestamp?
self.prev = None
self._image = None
self._image_history = []
self._diff = None
self._diff_mask = None
self._base64 = None

@property
def image(self) -> Image:
Expand Down
106 changes: 77 additions & 29 deletions openadapt/strategies/visual.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@


DEBUG = False
DEBUG_REPLAY = False
MAX_TOKENS = 2**14 # 16384


Expand All @@ -64,7 +65,8 @@ def add_active_segment_descriptions(
logical_events = []
for action in action_events:
# TODO: handle terminal <tab> event
if action.name in ("click", "doubleclick", "singleclick", "scroll"):
#if action.name in ("click", "doubleclick", "singleclick", "scroll"):
if action.name in common.MOUSE_EVENTS:
window_segmentation = get_window_segmentation(action)
active_segment_idx = get_active_segment(action, window_segmentation)
if not active_segment_idx:
Expand Down Expand Up @@ -98,16 +100,19 @@ def apply_replay_instructions(
actions=actions_dict,
replay_instructions=replay_instructions,
)
import ipdb; ipdb.set_trace()
prompt_adapter = get_default_prompt_adapter()
content = prompt_adapter.prompt(
prompt,
system_prompt,
max_tokens=max_tokens,
)
import ipdb; ipdb.set_trace()
content = utils.parse_code_snippet(content)
actions = models.ActionEvent.from_dict(content)
content_dict = utils.parse_code_snippet(content)
action_dicts = content_dict["actions"]
modified_actions = []
for action_dict in action_dicts:
action = models.ActionEvent.from_dict(action_dict)
modified_actions.append(action)
return modified_actions


def get_window_prompt_dict(active_window: models.WindowEvent) -> dict:
Expand Down Expand Up @@ -170,9 +175,11 @@ def __init__(
super().__init__(recording)
self.recording_action_idx = 0
add_active_segment_descriptions(recording.processed_action_events)
apply_replay_instructions(
self.modified_actions = apply_replay_instructions(
recording.processed_action_events, replay_instructions,
)
global DEBUG
DEBUG = DEBUG_REPLAY

def get_next_action_event(
self,
Expand All @@ -181,6 +188,9 @@ def get_next_action_event(
) -> models.ActionEvent:
"""Get the next ActionEvent for replay.
Since we have already modified the actions, this function just determines
the appropriate coordinates for the modified actions (where appropriate).
Args:
active_screenshot (models.Screenshot): The active screenshot object.
active_window (models.WindowEvent): The active window event object.
Expand All @@ -192,17 +202,55 @@ def get_next_action_event(
if self.recording_action_idx == len(self.recording.processed_action_events):
raise StopIteration()

reference_action = self.recording.processed_action_events[self.recording_action_idx]
reference_window = reference_action.window_event
# XXX HACK
while (ref_win_title_prefix := reference_window.title.split(" ")[0]) != (active_win_title_prefix := active_window.title.split(" ")[0]):
logger.warning(f"{ref_win_title_prefix=} != {active_win_title_prefix=}")
import time
time.sleep(1)
active_window = models.WindowEvent.get_active_window_event()
active_screenshot = models.Screenshot.take_screenshot()
logger.info(f"{active_window=}")

modified_reference_action = self.modified_actions[self.recording_action_idx]
self.recording_action_idx += 1

if modified_reference_action.name in common.MOUSE_EVENTS:
modified_reference_action.screenshot = active_screenshot
modified_reference_action.window_event = active_window
modified_reference_action.recording = self.recording
active_window_segmentation = get_window_segmentation(
modified_reference_action,
)
try:
target_segment_idx = active_window_segmentation.descriptions.index(
modified_reference_action.active_segment_description
)
except ValueError as exc:
logger.error(exc)
import ipdb; ipdb.set_trace()
# TODO: prompt model to determine closest match
target_centroid = active_window_segmentation.centroids[target_segment_idx]
# <position in image space> = scale_ratio * <position in window/action space>
width_ratio, height_ratio = utils.get_scale_ratios(modified_reference_action)
target_mouse_x = target_centroid[0] / width_ratio + active_window.left
target_mouse_y = target_centroid[1] / height_ratio + active_window.top
modified_reference_action.mouse_x = target_mouse_x
modified_reference_action.mouse_y = target_mouse_y
return modified_reference_action


def _get_next_action_event():
active_window_dict = get_window_prompt_dict(active_window)

actions = self.recording.processed_action_events
prev_window_title = None
prompt_frames = []
active_action = None
active_action_dict = None
num_images = 0
screenshots = []
#screenshots_by_window_title = get_screenshots_by_window_title(actions)
for action_idx, action in enumerate(actions):
for action_idx, action in enumerate(self.modified_actions):
window_dict = get_window_prompt_dict(action.window_event)
window_title = window_dict["title"]
if prev_window_title is None or prev_window_title != window_title:
Expand Down Expand Up @@ -252,7 +300,7 @@ def get_next_action_event(
if reference_action.name in common.MOUSE_EVENTS:
active_segmentation = get_window_segmentation(
screenshot=active_screenshot,
window_event=actions[0].window_event,
window_event=active_window,
)

# TODO: replace screenshots with window segmentations?
Expand Down Expand Up @@ -363,22 +411,9 @@ def _get_active_segment(action: models.ActionEvent, window_segmentation: Segment

def get_window_segmentation(
action_event: models.ActionEvent | None = None,
screenshot: models.Screenshot | None = None,
window_event: models.WindowEvent | None = None,
# TODO
#active_element_only: bool = True,
) -> Segmentation:
assert action_event or (screenshot and window_event)
if action_event:
screenshot = action_event.screenshot
screenshot.crop_active_window(action_event)
else:
width_ratio, height_ratio = utils.get_scale_ratios(window_event.action_events[0])
screenshot.crop_active_window(
window_event=window_event,
width_ratio=width_ratio,
height_ratio=height_ratio,
)
screenshot = action_event.screenshot
screenshot.crop_active_window(action_event)
original_image = screenshot.image
if DEBUG:
original_image.show()
Expand All @@ -400,10 +435,9 @@ def get_window_segmentation(
for masked_image in masked_images
]
descriptions = prompt_for_descriptions(
original_image_base64, masked_images_base64,
)
assert len(descriptions) == len(masked_images), (
len(descriptions), len(masked_images)
original_image_base64,
masked_images_base64,
action_event.active_segment_description,
)
bounding_boxes, centroids = vision.calculate_bounding_boxes(refined_masks)
assert len(bounding_boxes) == len(descriptions) == len(centroids), (
Expand Down Expand Up @@ -497,15 +531,19 @@ def prompt_for_action(
def prompt_for_descriptions(
original_image_base64: str,
masked_images_base64: list[str],
active_segment_description: str | None = None,
max_tokens: int | None = MAX_TOKENS,
) -> list[str]:
images = [original_image_base64] + masked_images_base64
system_prompt = utils.render_template_from_file(
"openadapt/prompts/system.j2",
)
logger.info(f"system_prompt=\n{system_prompt}")
num_segments = len(masked_images_base64)
prompt = utils.render_template_from_file(
"openadapt/prompts/description.j2",
active_segment_description=active_segment_description,
num_segments=num_segments,
)
logger.info(f"prompt=\n{prompt}")
prompt_adapter = get_default_prompt_adapter()
Expand All @@ -517,4 +555,14 @@ def prompt_for_descriptions(
)
descriptions = utils.parse_code_snippet(descriptions_json)["descriptions"]
logger.info(f"{descriptions=}")
try:
assert len(descriptions) == len(masked_images_base64), (
len(descriptions), len(masked_images_base64)
)
except Exception as exc:
logger.error(exc)
import ipdb; ipdb.set_trace()
foo = 1
# remove indexes
descriptions = [desc for idx, desc in descriptions]
return descriptions

0 comments on commit bd57862

Please sign in to comment.