diff --git a/openadapt/events.py b/openadapt/events.py index 67ee7f4c9..c46156cc8 100644 --- a/openadapt/events.py +++ b/openadapt/events.py @@ -18,6 +18,7 @@ MOUSE_MOVE_EVENT_MERGE_DISTANCE_THRESHOLD = 1 MOUSE_MOVE_EVENT_MERGE_MIN_IDX_DELTA = 5 KEYBOARD_EVENTS_MERGE_GROUP_NAMED_KEYS = True +USE_SCREENSHOT_DIFFS = False def get_events( @@ -148,7 +149,7 @@ def make_parent_event( def merge_consecutive_mouse_move_events( - events: list[models.ActionEvent], by_diff_distance: bool = False + events: list[models.ActionEvent], by_diff_distance: bool = USE_SCREENSHOT_DIFFS, ) -> list[models.ActionEvent]: """Merge consecutive mouse move events into a single move event. diff --git a/openadapt/models.py b/openadapt/models.py index 6c7e700b9..2a29f3842 100644 --- a/openadapt/models.py +++ b/openadapt/models.py @@ -270,7 +270,7 @@ def from_dict(cls, action_dict: dict) -> list['ActionEvent']: # TODO: use config.ACTION_TEXT_SEP, ACTION_TEXT_NAME_PREFIX/SUFFIX children = [] release_events = [] - try: + if "text" in action_dict: # Splitting actions based on whether they are special keys or regular characters if action_dict['text'].startswith('<') and action_dict['text'].endswith('>'): # Handling special keys @@ -288,12 +288,10 @@ def from_dict(cls, action_dict: dict) -> list['ActionEvent']: press, release = cls._create_key_events(key_char=key_char) children.append(press) children.append(release) - except KeyError as exc: - # Handle missing key names or canonical text appropriately - logger.warning(f"{exc=}") - - children += release_events[::-1] - return ActionEvent(children=children) + children += release_events[::-1] + return ActionEvent(children=children) + else: + return ActionEvent(**action_dict) @classmethod def _create_key_events( @@ -382,16 +380,24 @@ class Screenshot(db.Base): recording = sa.orm.relationship("Recording", back_populates="screenshots") action_event = sa.orm.relationship("ActionEvent", back_populates="screenshot") - # TODO: convert to png_data on save - sct_img = None - - # TODO: replace prev with prev_timestamp? - prev = None - _image = None - _image_history = [] - _diff = None - _diff_mask = None - _base64 = None + def __init__(self, *args, sct_img=None, **kwargs): + super().__init__(*args, **kwargs) + self.initialize_instance_attributes() + self.sct_img = sct_img + + @sa.orm.reconstructor + def initialize_instance_attributes(self): + """Initialize attributes for both new and loaded objects.""" + # TODO: convert to png_data on save + self.sct_img = None + + # TODO: replace prev with prev_timestamp? + self.prev = None + self._image = None + self._image_history = [] + self._diff = None + self._diff_mask = None + self._base64 = None @property def image(self) -> Image: diff --git a/openadapt/strategies/visual.py b/openadapt/strategies/visual.py index 2c4d188b8..7218ef1e9 100644 --- a/openadapt/strategies/visual.py +++ b/openadapt/strategies/visual.py @@ -45,6 +45,7 @@ DEBUG = False +DEBUG_REPLAY = False MAX_TOKENS = 2**14 # 16384 @@ -64,7 +65,8 @@ def add_active_segment_descriptions( logical_events = [] for action in action_events: # TODO: handle terminal event - if action.name in ("click", "doubleclick", "singleclick", "scroll"): + #if action.name in ("click", "doubleclick", "singleclick", "scroll"): + if action.name in common.MOUSE_EVENTS: window_segmentation = get_window_segmentation(action) active_segment_idx = get_active_segment(action, window_segmentation) if not active_segment_idx: @@ -98,16 +100,19 @@ def apply_replay_instructions( actions=actions_dict, replay_instructions=replay_instructions, ) - import ipdb; ipdb.set_trace() prompt_adapter = get_default_prompt_adapter() content = prompt_adapter.prompt( prompt, system_prompt, max_tokens=max_tokens, ) - import ipdb; ipdb.set_trace() - content = utils.parse_code_snippet(content) - actions = models.ActionEvent.from_dict(content) + content_dict = utils.parse_code_snippet(content) + action_dicts = content_dict["actions"] + modified_actions = [] + for action_dict in action_dicts: + action = models.ActionEvent.from_dict(action_dict) + modified_actions.append(action) + return modified_actions def get_window_prompt_dict(active_window: models.WindowEvent) -> dict: @@ -170,9 +175,11 @@ def __init__( super().__init__(recording) self.recording_action_idx = 0 add_active_segment_descriptions(recording.processed_action_events) - apply_replay_instructions( + self.modified_actions = apply_replay_instructions( recording.processed_action_events, replay_instructions, ) + global DEBUG + DEBUG = DEBUG_REPLAY def get_next_action_event( self, @@ -181,6 +188,9 @@ def get_next_action_event( ) -> models.ActionEvent: """Get the next ActionEvent for replay. + Since we have already modified the actions, this function just determines + the appropriate coordinates for the modified actions (where appropriate). + Args: active_screenshot (models.Screenshot): The active screenshot object. active_window (models.WindowEvent): The active window event object. @@ -192,17 +202,55 @@ def get_next_action_event( if self.recording_action_idx == len(self.recording.processed_action_events): raise StopIteration() + reference_action = self.recording.processed_action_events[self.recording_action_idx] + reference_window = reference_action.window_event + # XXX HACK + while (ref_win_title_prefix := reference_window.title.split(" ")[0]) != (active_win_title_prefix := active_window.title.split(" ")[0]): + logger.warning(f"{ref_win_title_prefix=} != {active_win_title_prefix=}") + import time + time.sleep(1) + active_window = models.WindowEvent.get_active_window_event() + active_screenshot = models.Screenshot.take_screenshot() + logger.info(f"{active_window=}") + + modified_reference_action = self.modified_actions[self.recording_action_idx] + self.recording_action_idx += 1 + + if modified_reference_action.name in common.MOUSE_EVENTS: + modified_reference_action.screenshot = active_screenshot + modified_reference_action.window_event = active_window + modified_reference_action.recording = self.recording + active_window_segmentation = get_window_segmentation( + modified_reference_action, + ) + try: + target_segment_idx = active_window_segmentation.descriptions.index( + modified_reference_action.active_segment_description + ) + except ValueError as exc: + logger.error(exc) + import ipdb; ipdb.set_trace() + # TODO: prompt model to determine closest match + target_centroid = active_window_segmentation.centroids[target_segment_idx] + # = scale_ratio * + width_ratio, height_ratio = utils.get_scale_ratios(modified_reference_action) + target_mouse_x = target_centroid[0] / width_ratio + active_window.left + target_mouse_y = target_centroid[1] / height_ratio + active_window.top + modified_reference_action.mouse_x = target_mouse_x + modified_reference_action.mouse_y = target_mouse_y + return modified_reference_action + + + def _get_next_action_event(): active_window_dict = get_window_prompt_dict(active_window) - actions = self.recording.processed_action_events prev_window_title = None prompt_frames = [] active_action = None active_action_dict = None num_images = 0 screenshots = [] - #screenshots_by_window_title = get_screenshots_by_window_title(actions) - for action_idx, action in enumerate(actions): + for action_idx, action in enumerate(self.modified_actions): window_dict = get_window_prompt_dict(action.window_event) window_title = window_dict["title"] if prev_window_title is None or prev_window_title != window_title: @@ -252,7 +300,7 @@ def get_next_action_event( if reference_action.name in common.MOUSE_EVENTS: active_segmentation = get_window_segmentation( screenshot=active_screenshot, - window_event=actions[0].window_event, + window_event=active_window, ) # TODO: replace screenshots with window segmentations? @@ -363,22 +411,9 @@ def _get_active_segment(action: models.ActionEvent, window_segmentation: Segment def get_window_segmentation( action_event: models.ActionEvent | None = None, - screenshot: models.Screenshot | None = None, - window_event: models.WindowEvent | None = None, - # TODO - #active_element_only: bool = True, ) -> Segmentation: - assert action_event or (screenshot and window_event) - if action_event: - screenshot = action_event.screenshot - screenshot.crop_active_window(action_event) - else: - width_ratio, height_ratio = utils.get_scale_ratios(window_event.action_events[0]) - screenshot.crop_active_window( - window_event=window_event, - width_ratio=width_ratio, - height_ratio=height_ratio, - ) + screenshot = action_event.screenshot + screenshot.crop_active_window(action_event) original_image = screenshot.image if DEBUG: original_image.show() @@ -400,10 +435,9 @@ def get_window_segmentation( for masked_image in masked_images ] descriptions = prompt_for_descriptions( - original_image_base64, masked_images_base64, - ) - assert len(descriptions) == len(masked_images), ( - len(descriptions), len(masked_images) + original_image_base64, + masked_images_base64, + action_event.active_segment_description, ) bounding_boxes, centroids = vision.calculate_bounding_boxes(refined_masks) assert len(bounding_boxes) == len(descriptions) == len(centroids), ( @@ -497,6 +531,7 @@ def prompt_for_action( def prompt_for_descriptions( original_image_base64: str, masked_images_base64: list[str], + active_segment_description: str | None = None, max_tokens: int | None = MAX_TOKENS, ) -> list[str]: images = [original_image_base64] + masked_images_base64 @@ -504,8 +539,11 @@ def prompt_for_descriptions( "openadapt/prompts/system.j2", ) logger.info(f"system_prompt=\n{system_prompt}") + num_segments = len(masked_images_base64) prompt = utils.render_template_from_file( "openadapt/prompts/description.j2", + active_segment_description=active_segment_description, + num_segments=num_segments, ) logger.info(f"prompt=\n{prompt}") prompt_adapter = get_default_prompt_adapter() @@ -517,4 +555,14 @@ def prompt_for_descriptions( ) descriptions = utils.parse_code_snippet(descriptions_json)["descriptions"] logger.info(f"{descriptions=}") + try: + assert len(descriptions) == len(masked_images_base64), ( + len(descriptions), len(masked_images_base64) + ) + except Exception as exc: + logger.error(exc) + import ipdb; ipdb.set_trace() + foo = 1 + # remove indexes + descriptions = [desc for idx, desc in descriptions] return descriptions