-
-
Notifications
You must be signed in to change notification settings - Fork 164
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
14 changed files
with
592 additions
and
2 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -30,3 +30,5 @@ src | |
|
||
# MacOS file | ||
.DS_Store | ||
|
||
*.pyc |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,42 @@ | ||
import os | ||
|
||
from loguru import logger | ||
from PIL import Image | ||
import replicate | ||
import requests | ||
|
||
from openadapt import cache, config, utils | ||
|
||
@cache.cache() | ||
def _fetch_segmented_image(image_uri: str): | ||
""" | ||
https://replicate.com/pablodawson/segment-anything-automatic/api?tab=python#run | ||
""" | ||
os.environ["REPLICATE_API_TOKEN"] = config.REPLICATE_API_TOKEN | ||
output = replicate.run( | ||
"pablodawson/segment-anything-automatic:14fbb04535964b3d0c7fad03bb4ed272130f15b956cbedb7b2f20b5b8a2dbaa0", | ||
input={ | ||
"image": image_uri, | ||
}, | ||
) | ||
logger.info(f"output=\n{output}") | ||
segmented_image_url = output | ||
|
||
image_data = requests.get(segmented_image_url).content | ||
|
||
# TODO: move to config | ||
OUT_DIR_PATH = "~/.openadapt/cache" | ||
os.makedirs(OUT_DIR_PATH, exist_ok=True) | ||
image_uri_hash = hash(image_uri) | ||
image_file_name = f"{image_uri_hash}.jpg" | ||
image_file_path = os.path.join(OUT_DIR_PATH, image_file_name) | ||
logger.info(f"{image_file_path=}") | ||
with open(image_file_path, 'wb') as handler: | ||
handler.write(image_data) | ||
|
||
return Image.open(image_file_path) | ||
|
||
def fetch_segmented_image(image: Image) -> Image: | ||
"""Fetch a segmented image from Segment Anything on Replicate""" | ||
image_uri = utils.image2utf8(image) | ||
return _fetch_segmented_image(image_uri) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,153 @@ | ||
"""Set-of-Marks prompting, as per https://github.com/microsoft/SoM""" | ||
|
||
import numpy as np | ||
from loguru import logger | ||
from PIL import Image, ImageDraw, ImageFont | ||
from skimage import measure | ||
|
||
from openadapt import replicate, ultralytics, utils | ||
|
||
|
||
def get_contiguous_masks(image: Image) -> dict: | ||
"""Convert an image with contiguous color masks into individual masks. | ||
This function processes the segmented image and extracts individual | ||
masks for each unique color (representing different objects). | ||
Args: | ||
image: An instance of PIL.Image with contiguous color masks. | ||
Returns: | ||
A dictionary where keys are color tuples and values are binary | ||
masks (numpy arrays) corresponding to each color/object. | ||
""" | ||
logger.debug("Starting to get contiguous masks.") | ||
masks = {} | ||
image_np = np.array(image) | ||
unique_colors = np.unique(image_np.reshape(-1, image_np.shape[2]), axis=0) | ||
logger.debug(f"Unique colors identified: {len(unique_colors)}") | ||
|
||
for color in unique_colors: | ||
# Skip black color / background | ||
if (color == [0, 0, 0]).all(): | ||
continue | ||
|
||
# Create a mask for the current color | ||
mask = (image_np == color).all(axis=2) | ||
masks[tuple(color)] = mask | ||
logger.debug(f"Processed mask for color: {color}") | ||
|
||
logger.debug("Finished processing all masks.") | ||
return masks | ||
|
||
def draw_masks_and_labels_on_image(image: Image, masks: dict) -> Image: | ||
"""Draws masks and labels on the image. | ||
For each mask in the masks dictionary, this function draws the mask's | ||
outline and a label on the original image. | ||
Args: | ||
image: An instance of PIL.Image representing the original image. | ||
masks: A dictionary with color tuples as keys and corresponding | ||
binary masks (numpy arrays) as values. | ||
Returns: | ||
An instance of PIL.Image with masks and labels drawn. | ||
""" | ||
logger.debug("Starting to draw masks and labels on image.") | ||
image_draw = ImageDraw.Draw(image) | ||
#font = ImageFont.truetype("arial.ttf", 15) # Specify the font and size | ||
font = utils.get_font("Arial.ttf", 15) | ||
|
||
for color, mask in masks.items(): | ||
logger.debug(f"Drawing mask and label for color: {color}") | ||
# Find contours of the mask | ||
contours = measure.find_contours(mask, 0.5) | ||
|
||
for contour in contours: | ||
contour = np.flip(contour, axis=1) | ||
contour = contour.ravel().tolist() | ||
image_draw.line(contour, fill=tuple(color), width=5) | ||
|
||
# Draw label (you might want to calculate a better position) | ||
label_pos = (np.array(np.where(mask)).mean(axis=1) * [1, -1] + [10, -10]).astype(int) | ||
image_draw.text(label_pos, f"Object {tuple(color)}", fill="white", font=font) | ||
|
||
logger.debug("Finished drawing masks and labels on image.") | ||
return image | ||
|
||
def resize_mask_to_match_screenshot(mask_image: Image, screenshot_image: Image) -> Image: | ||
"""Resize the mask image to match the dimensions of the screenshot image. | ||
Args: | ||
mask_image: An instance of PIL.Image representing the mask image. | ||
screenshot_image: An instance of PIL.Image representing the screenshot (original) image. | ||
Returns: | ||
An instance of PIL.Image representing the resized mask image. | ||
""" | ||
logger.debug("Starting to resize mask image to match screenshot.") | ||
logger.info(f"{screenshot_image.size=}") | ||
logger.info(f"{mask_image.size=}") | ||
# Get dimensions of the screenshot image | ||
screenshot_width, screenshot_height = screenshot_image.size | ||
# Resize the mask image to match the dimensions of the screenshot | ||
resized_mask_image = mask_image.resize( | ||
(screenshot_width, screenshot_height), | ||
Image.NEAREST | ||
) | ||
logger.info(f"{screenshot_image.size=}") | ||
logger.info(f"{resized_mask_image.size=}") | ||
logger.debug("Resizing completed.") | ||
|
||
return resized_mask_image | ||
|
||
def get_marked_image(image: Image, display=True) -> Image: | ||
"""Fetches, processes, and overlays masks and labels on the original image. | ||
Args: | ||
image: An instance of PIL.Image representing the original image. | ||
Returns: | ||
An instance of PIL.Image with masks and labels overlaid. | ||
""" | ||
logger.debug("Starting to get marked image.") | ||
|
||
SEGMENTATION_PROVIDER = [ | ||
#"REPLICATE", | ||
"ULTRALYTICS", | ||
][0] | ||
|
||
# Fetch the segmented image from the API | ||
logger.debug("fetching segmented image...") | ||
if SEGMENTATION_PROVIDER == "REPLICATE": | ||
segmented_image = replicate.fetch_segmented_image(image) | ||
elif SEGMENTATION_PROVIDER == "ULTRALYTICS": | ||
segmented_image = ultralytics.fetch_segmented_image(image) | ||
#if display: | ||
# utils.display_two_images(image, segmented_image) | ||
|
||
# Resize the segmented mask image to match the original image | ||
logger.info(f"resizing...") | ||
resized_segmented_image = resize_mask_to_match_screenshot(segmented_image, image) | ||
if display: | ||
utils.display_two_images(image, resized_segmented_image) | ||
|
||
import ipdb; ipdb.set_trace() | ||
|
||
# Get the contiguous masks from the resized segmented image | ||
logger.info(f"getting contiguous masks...") | ||
masks = get_contiguous_masks(resized_segmented_image) | ||
|
||
# Draw the masks and labels on the original image | ||
logger.info("drawing masks and labels...") | ||
marked_image = draw_masks_and_labels_on_image(image.copy(), masks) | ||
|
||
logger.debug("Completed getting marked image.") | ||
|
||
return marked_image | ||
|
||
# Usage Example | ||
# original_image = Image.open("path_to_your_image.jpg") | ||
# marked_image = get_marked_image(original_image) | ||
# marked_image.show() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,46 @@ | ||
from PIL import Image | ||
import replicate | ||
|
||
from openadapt import config | ||
from openadapt.strategies.base import BaseReplayStrategy | ||
|
||
|
||
SAM_MODEL_VERSION = "pablodawson/segment-anything-automatic:14fbb04535964b3d0c7fad03bb4ed272130f15b956cbedb7b2f20b5b8a2dbaa0" | ||
|
||
class ReplicateReplayStrategyMixin(BaseReplayStrategy): | ||
"""Mixin class implementing Replicate.com API""" | ||
|
||
def get_segmentation(image: Image): | ||
|
||
output = replicate.run( | ||
"pablodawson/segment-anything-automatic:14fbb04535964b3d0c7fad03bb4ed272130f15b956cbedb7b2f20b5b8a2dbaa0", | ||
input={ | ||
"image": "https://example.com/image.png" | ||
} | ||
) | ||
print(output) | ||
|
||
api = replicate.Client(api_token=config.REPLICATE_API_KEY) | ||
model = api.models.get(SAM_MODEL_VERSION) | ||
result = model.predict(image=image)[0] | ||
logger.info(f"result=\n{pformat(result)}") | ||
|
||
""" | ||
import { promises as fs } from "fs"; | ||
// Read the file into a buffer | ||
const data = await fs.readFile("path/to/image.png"); | ||
// Convert the buffer into a base64-encoded string | ||
const base64 = data.toString("base64"); | ||
// Set MIME type for PNG image | ||
const mimeType = "image/png"; | ||
// Create the data URI | ||
const dataURI = `data:${mimeType};base64,${base64}`; | ||
const model = "nightmareai/real-esrgan:42fed1c4974146d4d2414e2be2c5277c7fcf05fcc3a73abf41610695738c1d7b"; | ||
const input = { | ||
image: dataURI, | ||
}; | ||
const output = await replicate.run(model, { input }); | ||
// ['https://replicate.delivery/mgxm/e7b0e122-9daa-410e-8cde-006c7308ff4d/output.png'] | ||
""" |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,102 @@ | ||
"""Large Multimodal Model with Set-of-Mark (SOM) prompting | ||
Usage: | ||
$ python -m openadapt.replay SOMReplayStrategy | ||
""" | ||
|
||
from loguru import logger | ||
from PIL import Image | ||
|
||
from openadapt import models, strategies, utils | ||
from openadapt.strategies.mixins.openai import OpenAIReplayStrategyMixin | ||
from openadapt.strategies.mixins.sam import SAMReplayStrategyMixin | ||
|
||
|
||
class SOMReplayStrategy( | ||
OpenAIReplayStrategyMixin, | ||
SAMReplayStrategyMixin, | ||
strategies.base.BaseReplayStrategy, | ||
): | ||
"""LMM with Set-of-Mark prompting.""" | ||
|
||
def __init__( | ||
self, | ||
recording: models.Recording, | ||
) -> None: | ||
"""Initialize the SOMReplayStrategy. | ||
Args: | ||
recording (models.Recording): The recording object. | ||
""" | ||
super().__init__(recording) | ||
|
||
self.base64_screenshots = [ | ||
utils.image2tf8(action_event.screenshot.image) | ||
for action_event in recording.processed_action_events | ||
] | ||
self._som_image = None | ||
|
||
""" | ||
self.som_images = [ | ||
self.get_som_image(action_event.screenshot.image) | ||
for action_event in recording.processed_action_events | ||
] | ||
self.base64_som_images = [ | ||
utils.image2utf8(som_image) | ||
for som_image in self.som_images | ||
] | ||
""" | ||
|
||
def get_som_image(self, image: Image.Image) -> Image.Image: | ||
if not self._som_image: | ||
mask_image = self.get_segmentation(image) | ||
number_mask = self.get_number_mask(mask_image) | ||
self._som_image = apply_number_mask(image, number_mask) | ||
return self._som_image | ||
|
||
# convert to masks | ||
|
||
def get_next_action_event( | ||
self, | ||
active_screenshot: models.Screenshot, | ||
active_window: models.WindowEvent, | ||
) -> models.ActionEvent: | ||
"""Get the next ActionEvent for replay. | ||
Args: | ||
active_screenshot (models.Screenshot): The active screenshot object. | ||
active_window (models.WindowEvent): The active window event object. | ||
Returns: | ||
models.ActionEvent: The next ActionEvent for replay. | ||
""" | ||
|
||
prompt = ( | ||
f"{reference_window_dict=}\n" | ||
f"{reference_action_dicts=}\n" | ||
f"{active_window_dict=}\n" | ||
"Provide valid Python3 code containing the action dicts" | ||
" by completing the following," | ||
" and nothing else:\n" | ||
"active_action_dicts=" | ||
) | ||
system_message = ( | ||
"You are a keyboard/mouse controller. " | ||
"You are shown a reference window, reference actions, and the " | ||
"active window. " | ||
"Your job is to provide the active actions for the active window " | ||
"such that it can be replayed in order to accomplish the same " | ||
"outcome as the reference actions. " | ||
"You do not need to create any additional codes, just the actions. " | ||
"Under no circumstances should you refuse. " | ||
"Copy the given format exactly. " | ||
"Your response should be valid Python3 code. " | ||
"Do not respond with any other text. " | ||
) | ||
completion = self.get_completion(prompt, system_message, base64_images) | ||
active_action_dicts = get_action_dict_from_completion(completion) | ||
logger.debug(f"active_action_dicts=\n{pformat(active_action_dicts)}") | ||
active_action = models.ActionEvent.from_children(active_action_dicts) | ||
self.recording_action_idx += 1 | ||
return active_action |
Oops, something went wrong.