Skip to content

Commit

Permalink
wip (working completions)
Browse files Browse the repository at this point in the history
  • Loading branch information
abrichr committed Apr 5, 2024
1 parent 04e317a commit b7acb4d
Show file tree
Hide file tree
Showing 15 changed files with 433 additions and 202 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
"""add active_segment_description and available_segment_descriptions
Revision ID: 30a5ba9d6453
Revises: 530f0663324e
Create Date: 2024-04-05 12:02:57.843244
"""
from alembic import op
import sqlalchemy as sa


# revision identifiers, used by Alembic.
revision = '30a5ba9d6453'
down_revision = '530f0663324e'
branch_labels = None
depends_on = None


def upgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
with op.batch_alter_table('action_event', schema=None) as batch_op:
batch_op.add_column(sa.Column('active_segment_description', sa.String(), nullable=True))
batch_op.add_column(sa.Column('available_segment_descriptions', sa.String(), nullable=True))

# ### end Alembic commands ###


def downgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
with op.batch_alter_table('action_event', schema=None) as batch_op:
batch_op.drop_column('available_segment_descriptions')
batch_op.drop_column('active_segment_description')

# ### end Alembic commands ###
28 changes: 17 additions & 11 deletions openadapt/adapters/som.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
import os
import fire
from gradio_client import Client
import tempfile

from loguru import logger
from PIL import Image
import tempfile
import fire
import gradio_client

from openadapt import config


Expand All @@ -17,8 +19,8 @@ def save_image_to_temp_file(img: Image.Image) -> str:
return temp_file.name


def predict(server_url: str = config.SOM_SERVER_URL,
file_path: str = None,
def predict(file_path: str = None,
server_url: str = config.SOM_SERVER_URL,
granularity: float = 2.7,
segmentation_mode: str = "Automatic",
mask_alpha: float = 0.8,
Expand All @@ -39,10 +41,12 @@ def predict(server_url: str = config.SOM_SERVER_URL,
api_name (str): API endpoint name for inference.
"""

client = Client(server_url)
assert server_url, server_url
assert server_url.startswith("http"), server_url
client = gradio_client.Client(server_url)
result = client.predict(
{
"background": file_path,
"background": gradio_client.file(file_path),
},
granularity,
segmentation_mode,
Expand All @@ -56,18 +60,20 @@ def predict(server_url: str = config.SOM_SERVER_URL,
return result


def predict_for_image(image: Image.Image):
def fetch_segmented_image(image: Image.Image):
"""
Process an image directly using PIL.Image.Image object and predict using the Gradio client.
Args:
image (PIL.Image.Image): A PIL image object.
"""
img_temp_path = save_image_to_temp_file(image) # Save the image to a temp file
segmented_image = predict(file_path=img_temp_path) # Perform prediction
segmented_image_path = predict(file_path=img_temp_path) # Perform prediction
os.remove(img_temp_path) # Delete the temp file after prediction
return segmented_image
image = Image.open(segmented_image_path)
os.remove(segmented_image_path)
return image


if __name__ == "__main__":
fire.Fire(predict_for_image)
fire.Fire(predict)
2 changes: 1 addition & 1 deletion openadapt/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
"SOM_SERVER_URL": "<set in .env>",
"REPLICATE_API_TOKEN": "<set your api key in .env>",
"DEFAULT_ADAPTER": "openai",
"DEFAULT_SEGMENTATION_ADAPTER": "som", # "som" or "replicate"
"DEFAULT_SEGMENTATION_ADAPTER": "replicate", # "som" or "replicate"
"ANTHROPIC_API_KEY": "<set your api key in .env>",
"CACHE_DIR_PATH": ".cache",
"CACHE_ENABLED": True,
Expand Down
27 changes: 27 additions & 0 deletions openadapt/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,8 @@ class ActionEvent(db.Base):

__tablename__ = "action_event"

_segment_description_separator = ";"

id = sa.Column(sa.Integer, primary_key=True)
name = sa.Column(sa.String)
timestamp = sa.Column(ForceFloat)
Expand All @@ -89,6 +91,10 @@ class ActionEvent(db.Base):
mouse_y = sa.Column(sa.Numeric(asdecimal=False))
mouse_dx = sa.Column(sa.Numeric(asdecimal=False))
mouse_dy = sa.Column(sa.Numeric(asdecimal=False))
active_segment_description = sa.Column(sa.String)
_available_segment_descriptions = sa.Column(
"available_segment_descriptions", sa.String,
)
mouse_button_name = sa.Column(sa.String)
mouse_pressed = sa.Column(sa.Boolean)
key_name = sa.Column(sa.String)
Expand All @@ -100,6 +106,21 @@ class ActionEvent(db.Base):
parent_id = sa.Column(sa.Integer, sa.ForeignKey("action_event.id"))
element_state = sa.Column(sa.JSON)

@property
def available_segment_descriptions(self) -> list[str]:
if self._available_segment_descriptions:
return self._available_segment_descriptions.split(
self._segment_description_separator
)
else:
return []

@available_segment_descriptions.setter
def available_segment_descriptions(self, value: list[str]):
self._available_segment_descriptions = self._segment_description_separator.join(
value
)

children = sa.orm.relationship("ActionEvent")
# TODO: replacing the above line with the following two results in an error:
# AttributeError: 'list' object has no attribute '_sa_instance_state'
Expand Down Expand Up @@ -458,6 +479,12 @@ def crop_active_window(
self._image_history.append(self.image)
self._image = self._image.crop(box)

@property
def original_image(self) -> Image:
if self._image_history:
return self._image_history[0]
return self.image

def convert_binary_to_png(self, image_binary: bytes) -> Image:
"""Convert a binary image to a PNG image.
Expand Down
8 changes: 4 additions & 4 deletions openadapt/prompts/action.j2
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,9 @@ Here are all of the reduced actions in the recording:
{% for prompt_frame in prompt_frames %}
###
Action {{ prompt_frame['action_number'] }} of {{ num_actions }}:
{{ prompt_frame['action'] | orjson_to_json }}
{{ prompt_frame['action'] | orjson }}
Corresponding window:
{{ prompt_frame['window'] | orjson_to_json }}
{{ prompt_frame['window'] | orjson }}
Attached image {{ prompt_frame['screenshot_number'] }} of {{ num_images }}
Active Segment Description: {{ prompt_frame['active_segment_description'] }}
###
Expand All @@ -19,7 +19,7 @@ Here are all of the reduced actions in the recording:
Your job is to provide a (possibly modified) version of the reference action:

```json
{{ reference_action | orjson_to_json }}
{{ reference_action | orjson }}
```

{% if active_segmentation %}
Expand All @@ -33,7 +33,7 @@ Centroid: {{ centroid }}
The active window is:

```json
{{ active_window | orjson_to_json }}
{{ active_window | orjson }}
```

Image {{ num_images }} is the active screenshot.
Expand Down
19 changes: 19 additions & 0 deletions openadapt/prompts/apply_replay_instructions.j2
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
Consider the actions in the recording:

```json
{{ actions }}
```

Consider the user's replay instructions:
```text
{{ replay_instructions }}
```

Provide an updated list of actions that are modified such that replaying them will
accomplish the user's replay instructions.

Do NOT provide available_segment_descriptions in your response.

Respond with json and nothing else.

My career depends on this. Lives are at stake.
10 changes: 5 additions & 5 deletions openadapt/replay.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,9 +39,9 @@ def replay(
timestamp (str, optional): Timestamp of the recording to replay.
recording (Recording, optional): Recording to replay.
record (bool, optional): Flag indicating whether to record the replay.
instructions (str, optional): Natural language instructions to the model, e.g.
description of what are the parameters in the recording, and how the replay
should behave as a function of those parameters.
instructions (str, optional): Natural language instructions to the
model, e.g. description of what are the parameters in the recording, and
how the replay should behave as a function of those parameters.
Returns:
bool: True if replay was successful, None otherwise.
Expand Down Expand Up @@ -71,7 +71,7 @@ def replay(
strategy_class = strategy_class_by_name[strategy_name]
logger.info(f"{strategy_class=}")

strategy = strategy_class(recording)
strategy = strategy_class(recording, instructions)
logger.info(f"{strategy=}")

handler = None
Expand All @@ -87,7 +87,7 @@ def replay(
logger.info(f"{file_path=}")
handler = logger.add(open(file_path, "w"))
try:
strategy.run(instructions)
strategy.run()
except Exception as e:
logger.exception(e)
rval = False
Expand Down
10 changes: 4 additions & 6 deletions openadapt/strategies/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,12 +22,15 @@ class BaseReplayStrategy(ABC):
def __init__(
self,
recording: models.Recording,
replay_instructions: str | None = None,
max_frame_times: int = MAX_FRAME_TIMES,
) -> None:
"""Initialize the BaseReplayStrategy.
Args:
recording (models.Recording): The recording to replay.
replay_isntructions (str, optional): Natural language instructions
for how recording should be replayed.
max_frame_times (int): The maximum number of frame times to track.
"""
self.recording = recording
Expand All @@ -41,22 +44,18 @@ def __init__(
def get_next_action_event(
self,
screenshot: models.Screenshot,
instructions: str | None,
) -> models.ActionEvent:
"""Get the next action event based on the current screenshot.
Args:
screenshot (models.Screenshot): The current screenshot.
instructions (str, optional): Natural language instructions to the model,
e.g. description of what are the parameters in the recording, and how
the replay should behave as a function of those parameters.
Returns:
models.ActionEvent: The next action event.
"""
pass

def run(self, instructions: str | None) -> None:
def run(self) -> None:
"""Run the replay strategy."""
keyboard_controller = keyboard.Controller()
mouse_controller = mouse.Controller()
Expand All @@ -69,7 +68,6 @@ def run(self, instructions: str | None) -> None:
action_event = self.get_next_action_event(
screenshot,
window_event,
instructions,
)
except StopIteration:
break
Expand Down
5 changes: 4 additions & 1 deletion openadapt/strategies/demo.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,13 +33,16 @@ class DemoReplayStrategy(
def __init__(
self,
recording: Recording,
replay_instructions: str | None = None,
) -> None:
"""Initialize the DemoReplayStrategy.
Args:
recording (Recording): The recording to replay.
replay_instructions (str): Natural language instructions
for how recording should be replayed.
"""
super().__init__(recording)
super().__init__(recording, replay_instructions)
self.result_history = []
self.screenshots = get_screenshots(recording)
self.screenshot_idx = 0
Expand Down
3 changes: 3 additions & 0 deletions openadapt/strategies/naive.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ class NaiveReplayStrategy(strategies.base.BaseReplayStrategy):
def __init__(
self,
recording: models.Recording,
replay_instructions: str | None = None,
display_events: bool = DISPLAY_EVENTS,
replay_events: bool = REPLAY_EVENTS,
process_events: bool = PROCESS_EVENTS,
Expand All @@ -28,6 +29,8 @@ def __init__(
Args:
recording (models.Recording): The recording object.
replay_instructions (str): Natural language instructions
for how recording should be replayed.
display_events (bool): Flag indicating whether to display the events.
Defaults to False.
replay_events (bool): Flag indicating whether to replay the events.
Expand Down
5 changes: 4 additions & 1 deletion openadapt/strategies/stateful.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,13 +29,16 @@ class StatefulReplayStrategy(
def __init__(
self,
recording: models.Recording,
replay_instructions: str | None = None,
) -> None:
"""Initialize the StatefulReplayStrategy.
Args:
recording (models.Recording): The recording object.
replay_instructions (str): Natural language instructions
for how recording should be replayed.
"""
super().__init__(recording)
super().__init__(recording, replay_instructions)
self.recording_window_state_diffs = get_window_state_diffs(
recording.processed_action_events
)
Expand Down
Loading

0 comments on commit b7acb4d

Please sign in to comment.