Skip to content

Commit

Permalink
replicate/ultralytics/som wip
Browse files Browse the repository at this point in the history
  • Loading branch information
abrichr committed Jan 25, 2024
1 parent 44d4a55 commit ce47e7a
Show file tree
Hide file tree
Showing 14 changed files with 592 additions and 2 deletions.
2 changes: 2 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -30,3 +30,5 @@ src

# MacOS file
.DS_Store

*.pyc
3 changes: 3 additions & 0 deletions openadapt/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
import git
import sentry_sdk

# TODO: replace with pydantic
_DEFAULTS = {
"CACHE_DIR_PATH": ".cache",
"CACHE_ENABLED": True,
Expand Down Expand Up @@ -111,6 +112,8 @@
"SAVE_SCREENSHOT_DIFF": False,
"SPACY_MODEL_NAME": "en_core_web_trf",
"PRIVATE_AI_API_KEY": "<set your api key in .env>",
"REPLICATE_API_TOKEN": "<set your api key in .env>",
"ULTRALYTICS_API_KEY": "<set your api key in .env>",
}

# each string in STOP_STRS should only contain strings
Expand Down
13 changes: 13 additions & 0 deletions openadapt/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -261,6 +261,7 @@ class Screenshot(db.Base):
_diff_mask = None

_base64 = None
_marked_image = None

@property
def image(self) -> Image.Image:
Expand Down Expand Up @@ -354,6 +355,18 @@ def convert_png_to_binary(self, image: Image.Image) -> bytes:
image.save(buffer, format="PNG")
return buffer.getvalue()

def base64(self) -> str:
"""Returns data URI of JPEG encoded base64"""
if not _base64:
_base64 = utils.image2utf8(self.image)
return _base64

@property
def marked_image(self) -> Image:
if not self._marked_image:
self._marked_image = utils.get_marked_image(self.base64)
return self._marked_image


class WindowEvent(db.Base):
"""Class representing a window event in the database."""
Expand Down
42 changes: 42 additions & 0 deletions openadapt/replicate.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
import os

from loguru import logger
from PIL import Image
import replicate
import requests

from openadapt import cache, config, utils

@cache.cache()
def _fetch_segmented_image(image_uri: str):
"""
https://replicate.com/pablodawson/segment-anything-automatic/api?tab=python#run
"""
os.environ["REPLICATE_API_TOKEN"] = config.REPLICATE_API_TOKEN
output = replicate.run(
"pablodawson/segment-anything-automatic:14fbb04535964b3d0c7fad03bb4ed272130f15b956cbedb7b2f20b5b8a2dbaa0",
input={
"image": image_uri,
},
)
logger.info(f"output=\n{output}")
segmented_image_url = output

image_data = requests.get(segmented_image_url).content

# TODO: move to config
OUT_DIR_PATH = "~/.openadapt/cache"
os.makedirs(OUT_DIR_PATH, exist_ok=True)
image_uri_hash = hash(image_uri)
image_file_name = f"{image_uri_hash}.jpg"
image_file_path = os.path.join(OUT_DIR_PATH, image_file_name)
logger.info(f"{image_file_path=}")
with open(image_file_path, 'wb') as handler:
handler.write(image_data)

return Image.open(image_file_path)

def fetch_segmented_image(image: Image) -> Image:
"""Fetch a segmented image from Segment Anything on Replicate"""
image_uri = utils.image2utf8(image)
return _fetch_segmented_image(image_uri)
153 changes: 153 additions & 0 deletions openadapt/som.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,153 @@
"""Set-of-Marks prompting, as per https://github.com/microsoft/SoM"""

import numpy as np
from loguru import logger
from PIL import Image, ImageDraw, ImageFont
from skimage import measure

from openadapt import replicate, ultralytics, utils


def get_contiguous_masks(image: Image) -> dict:
"""Convert an image with contiguous color masks into individual masks.
This function processes the segmented image and extracts individual
masks for each unique color (representing different objects).
Args:
image: An instance of PIL.Image with contiguous color masks.
Returns:
A dictionary where keys are color tuples and values are binary
masks (numpy arrays) corresponding to each color/object.
"""
logger.debug("Starting to get contiguous masks.")
masks = {}
image_np = np.array(image)
unique_colors = np.unique(image_np.reshape(-1, image_np.shape[2]), axis=0)
logger.debug(f"Unique colors identified: {len(unique_colors)}")

for color in unique_colors:
# Skip black color / background
if (color == [0, 0, 0]).all():
continue

# Create a mask for the current color
mask = (image_np == color).all(axis=2)
masks[tuple(color)] = mask
logger.debug(f"Processed mask for color: {color}")

logger.debug("Finished processing all masks.")
return masks

def draw_masks_and_labels_on_image(image: Image, masks: dict) -> Image:
"""Draws masks and labels on the image.
For each mask in the masks dictionary, this function draws the mask's
outline and a label on the original image.
Args:
image: An instance of PIL.Image representing the original image.
masks: A dictionary with color tuples as keys and corresponding
binary masks (numpy arrays) as values.
Returns:
An instance of PIL.Image with masks and labels drawn.
"""
logger.debug("Starting to draw masks and labels on image.")
image_draw = ImageDraw.Draw(image)
#font = ImageFont.truetype("arial.ttf", 15) # Specify the font and size
font = utils.get_font("Arial.ttf", 15)

for color, mask in masks.items():
logger.debug(f"Drawing mask and label for color: {color}")
# Find contours of the mask
contours = measure.find_contours(mask, 0.5)

for contour in contours:
contour = np.flip(contour, axis=1)
contour = contour.ravel().tolist()
image_draw.line(contour, fill=tuple(color), width=5)

# Draw label (you might want to calculate a better position)
label_pos = (np.array(np.where(mask)).mean(axis=1) * [1, -1] + [10, -10]).astype(int)
image_draw.text(label_pos, f"Object {tuple(color)}", fill="white", font=font)

logger.debug("Finished drawing masks and labels on image.")
return image

def resize_mask_to_match_screenshot(mask_image: Image, screenshot_image: Image) -> Image:
"""Resize the mask image to match the dimensions of the screenshot image.
Args:
mask_image: An instance of PIL.Image representing the mask image.
screenshot_image: An instance of PIL.Image representing the screenshot (original) image.
Returns:
An instance of PIL.Image representing the resized mask image.
"""
logger.debug("Starting to resize mask image to match screenshot.")
logger.info(f"{screenshot_image.size=}")
logger.info(f"{mask_image.size=}")
# Get dimensions of the screenshot image
screenshot_width, screenshot_height = screenshot_image.size
# Resize the mask image to match the dimensions of the screenshot
resized_mask_image = mask_image.resize(
(screenshot_width, screenshot_height),
Image.NEAREST
)
logger.info(f"{screenshot_image.size=}")
logger.info(f"{resized_mask_image.size=}")
logger.debug("Resizing completed.")

return resized_mask_image

def get_marked_image(image: Image, display=True) -> Image:
"""Fetches, processes, and overlays masks and labels on the original image.
Args:
image: An instance of PIL.Image representing the original image.
Returns:
An instance of PIL.Image with masks and labels overlaid.
"""
logger.debug("Starting to get marked image.")

SEGMENTATION_PROVIDER = [
#"REPLICATE",
"ULTRALYTICS",
][0]

# Fetch the segmented image from the API
logger.debug("fetching segmented image...")
if SEGMENTATION_PROVIDER == "REPLICATE":
segmented_image = replicate.fetch_segmented_image(image)
elif SEGMENTATION_PROVIDER == "ULTRALYTICS":
segmented_image = ultralytics.fetch_segmented_image(image)
#if display:
# utils.display_two_images(image, segmented_image)

# Resize the segmented mask image to match the original image
logger.info(f"resizing...")
resized_segmented_image = resize_mask_to_match_screenshot(segmented_image, image)
if display:
utils.display_two_images(image, resized_segmented_image)

import ipdb; ipdb.set_trace()

# Get the contiguous masks from the resized segmented image
logger.info(f"getting contiguous masks...")
masks = get_contiguous_masks(resized_segmented_image)

# Draw the masks and labels on the original image
logger.info("drawing masks and labels...")
marked_image = draw_masks_and_labels_on_image(image.copy(), masks)

logger.debug("Completed getting marked image.")

return marked_image

# Usage Example
# original_image = Image.open("path_to_your_image.jpg")
# marked_image = get_marked_image(original_image)
# marked_image.show()
46 changes: 46 additions & 0 deletions openadapt/strategies/mixins/replicate.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
from PIL import Image
import replicate

from openadapt import config
from openadapt.strategies.base import BaseReplayStrategy


SAM_MODEL_VERSION = "pablodawson/segment-anything-automatic:14fbb04535964b3d0c7fad03bb4ed272130f15b956cbedb7b2f20b5b8a2dbaa0"

class ReplicateReplayStrategyMixin(BaseReplayStrategy):
"""Mixin class implementing Replicate.com API"""

def get_segmentation(image: Image):

output = replicate.run(
"pablodawson/segment-anything-automatic:14fbb04535964b3d0c7fad03bb4ed272130f15b956cbedb7b2f20b5b8a2dbaa0",
input={
"image": "https://example.com/image.png"
}
)
print(output)

api = replicate.Client(api_token=config.REPLICATE_API_KEY)
model = api.models.get(SAM_MODEL_VERSION)
result = model.predict(image=image)[0]
logger.info(f"result=\n{pformat(result)}")

"""
import { promises as fs } from "fs";
// Read the file into a buffer
const data = await fs.readFile("path/to/image.png");
// Convert the buffer into a base64-encoded string
const base64 = data.toString("base64");
// Set MIME type for PNG image
const mimeType = "image/png";
// Create the data URI
const dataURI = `data:${mimeType};base64,${base64}`;
const model = "nightmareai/real-esrgan:42fed1c4974146d4d2414e2be2c5277c7fcf05fcc3a73abf41610695738c1d7b";
const input = {
image: dataURI,
};
const output = await replicate.run(model, { input });
// ['https://replicate.delivery/mgxm/e7b0e122-9daa-410e-8cde-006c7308ff4d/output.png']
"""
102 changes: 102 additions & 0 deletions openadapt/strategies/som.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,102 @@
"""Large Multimodal Model with Set-of-Mark (SOM) prompting
Usage:
$ python -m openadapt.replay SOMReplayStrategy
"""

from loguru import logger
from PIL import Image

from openadapt import models, strategies, utils
from openadapt.strategies.mixins.openai import OpenAIReplayStrategyMixin
from openadapt.strategies.mixins.sam import SAMReplayStrategyMixin


class SOMReplayStrategy(
OpenAIReplayStrategyMixin,
SAMReplayStrategyMixin,
strategies.base.BaseReplayStrategy,
):
"""LMM with Set-of-Mark prompting."""

def __init__(
self,
recording: models.Recording,
) -> None:
"""Initialize the SOMReplayStrategy.
Args:
recording (models.Recording): The recording object.
"""
super().__init__(recording)

self.base64_screenshots = [
utils.image2tf8(action_event.screenshot.image)
for action_event in recording.processed_action_events
]
self._som_image = None

"""
self.som_images = [
self.get_som_image(action_event.screenshot.image)
for action_event in recording.processed_action_events
]
self.base64_som_images = [
utils.image2utf8(som_image)
for som_image in self.som_images
]
"""

def get_som_image(self, image: Image.Image) -> Image.Image:
if not self._som_image:
mask_image = self.get_segmentation(image)
number_mask = self.get_number_mask(mask_image)
self._som_image = apply_number_mask(image, number_mask)
return self._som_image

# convert to masks

def get_next_action_event(
self,
active_screenshot: models.Screenshot,
active_window: models.WindowEvent,
) -> models.ActionEvent:
"""Get the next ActionEvent for replay.
Args:
active_screenshot (models.Screenshot): The active screenshot object.
active_window (models.WindowEvent): The active window event object.
Returns:
models.ActionEvent: The next ActionEvent for replay.
"""

prompt = (
f"{reference_window_dict=}\n"
f"{reference_action_dicts=}\n"
f"{active_window_dict=}\n"
"Provide valid Python3 code containing the action dicts"
" by completing the following,"
" and nothing else:\n"
"active_action_dicts="
)
system_message = (
"You are a keyboard/mouse controller. "
"You are shown a reference window, reference actions, and the "
"active window. "
"Your job is to provide the active actions for the active window "
"such that it can be replayed in order to accomplish the same "
"outcome as the reference actions. "
"You do not need to create any additional codes, just the actions. "
"Under no circumstances should you refuse. "
"Copy the given format exactly. "
"Your response should be valid Python3 code. "
"Do not respond with any other text. "
)
completion = self.get_completion(prompt, system_message, base64_images)
active_action_dicts = get_action_dict_from_completion(completion)
logger.debug(f"active_action_dicts=\n{pformat(active_action_dicts)}")
active_action = models.ActionEvent.from_children(active_action_dicts)
self.recording_action_idx += 1
return active_action
Loading

0 comments on commit ce47e7a

Please sign in to comment.