forked from OpenAdaptAI/OpenAdapt
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Implement model cursor for visual feedback (OpenAdaptAI#760)
- Add CursorReplayStrategy class with visual feedback - Implement self-correction mechanism - Add comprehensive test suite - Include demo with examples
- Loading branch information
Showing
13 changed files
with
276 additions
and
18 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,78 @@ | ||
import numpy as np | ||
import cv2 | ||
from replay_strategies.cursor_replay import CursorReplayStrategy | ||
|
||
def create_test_image(): | ||
"""Create a test image with various UI elements.""" | ||
# Create a 400x300 image with a white background | ||
image = np.ones((300, 400, 3), dtype=np.uint8) * 255 | ||
|
||
# Add UI elements | ||
# Button | ||
cv2.rectangle(image, (50, 50), (150, 100), (200, 200, 200), -1) | ||
cv2.rectangle(image, (50, 50), (150, 100), (100, 100, 100), 2) | ||
cv2.putText(image, "Button", (70, 85), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 0, 0), 1) | ||
|
||
# Checkbox | ||
cv2.rectangle(image, (200, 50), (220, 70), (255, 255, 255), -1) | ||
cv2.rectangle(image, (200, 50), (220, 70), (100, 100, 100), 1) | ||
|
||
# Text input field | ||
cv2.rectangle(image, (50, 150), (350, 180), (255, 255, 255), -1) | ||
cv2.rectangle(image, (50, 150), (350, 180), (100, 100, 100), 1) | ||
cv2.putText(image, "Text input", (60, 170), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (150, 150, 150), 1) | ||
|
||
# Dropdown | ||
cv2.rectangle(image, (50, 200), (150, 230), (240, 240, 240), -1) | ||
cv2.rectangle(image, (50, 200), (150, 230), (100, 100, 100), 1) | ||
cv2.putText(image, "▼", (130, 220), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 0, 0), 1) | ||
|
||
return image | ||
|
||
def test_cursor_strategy(image, strategy, name, coords): | ||
"""Test the cursor strategy with given coordinates.""" | ||
# Process coordinates | ||
processed_coords = strategy.process_target(image, coords) | ||
print(f"\n{name}:") | ||
print(f"Original coordinates: {coords}") | ||
print(f"Processed coordinates: {processed_coords}") | ||
|
||
# Create visualization | ||
visualization = strategy.visualize_target(image, processed_coords) | ||
filename = f'visualization_{name.lower().replace(" ", "_")}.png' | ||
cv2.imwrite(filename, visualization) | ||
print(f"Created {filename}") | ||
|
||
return visualization | ||
|
||
def main(): | ||
# Create test image | ||
test_image = create_test_image() | ||
cv2.imwrite('original.png', test_image) | ||
print("Created original.png") | ||
|
||
# Initialize strategy with self-correction | ||
strategy = CursorReplayStrategy( | ||
dot_radius=6, | ||
dot_color=(0, 0, 255), # Red in BGR | ||
enable_self_correction=True, | ||
confidence_threshold=0.7 | ||
) | ||
|
||
# Test cases | ||
test_cases = [ | ||
("Button Click", (100, 80)), # Near button center | ||
("Checkbox", (210, 55)), # Slightly off checkbox | ||
("Text Input", (200, 165)), # Text input field | ||
("Dropdown", (140, 215)), # Dropdown arrow | ||
("Out of Bounds", (500, 500)) # Out of bounds | ||
] | ||
|
||
# Run tests | ||
for name, coords in test_cases: | ||
test_cursor_strategy(test_image, strategy, name, coords) | ||
|
||
print("\nAll visualizations have been created!") | ||
|
||
if __name__ == "__main__": | ||
main() |
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,19 +1,4 @@ | ||
"""Package containing different replay strategies. | ||
from .base import VanillaReplayStrategy | ||
from .cursor_replay import CursorReplayStrategy | ||
|
||
Module: __init__.py | ||
""" | ||
|
||
# flake8: noqa | ||
|
||
from openadapt.strategies.base import BaseReplayStrategy | ||
from openadapt.strategies.visual_browser import VisualBrowserReplayStrategy | ||
|
||
# disabled because importing is expensive | ||
# from openadapt.strategies.demo import DemoReplayStrategy | ||
from openadapt.strategies.naive import NaiveReplayStrategy | ||
from openadapt.strategies.segment import SegmentReplayStrategy | ||
from openadapt.strategies.stateful import StatefulReplayStrategy | ||
from openadapt.strategies.vanilla import VanillaReplayStrategy | ||
from openadapt.strategies.visual import VisualReplayStrategy | ||
|
||
# add more strategies here | ||
__all__ = ['VanillaReplayStrategy', 'CursorReplayStrategy'] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,142 @@ | ||
import numpy as np | ||
import cv2 | ||
from typing import Tuple, Optional | ||
from .base import VanillaReplayStrategy | ||
|
||
class CursorReplayStrategy(VanillaReplayStrategy): | ||
def __init__(self, | ||
dot_radius: int = 5, | ||
dot_color: Tuple[int, int, int] = (0, 0, 255), | ||
enable_self_correction: bool = False, | ||
confidence_threshold: float = 0.7): | ||
"""Initialize the CursorReplayStrategy. | ||
Args: | ||
dot_radius: Radius of the cursor dot in pixels | ||
dot_color: Color of the dot in BGR format | ||
enable_self_correction: Whether to enable visual self-correction | ||
confidence_threshold: Threshold for self-correction confidence | ||
""" | ||
self.dot_radius = dot_radius | ||
self.dot_color = dot_color | ||
self.enable_self_correction = enable_self_correction | ||
self.confidence_threshold = confidence_threshold | ||
self.last_correction = None | ||
|
||
def _analyze_target_region(self, screenshot: np.ndarray, target_coords: Tuple[int, int]) -> Optional[Tuple[int, int]]: | ||
"""Analyze the region around the target coordinates for better positioning. | ||
Uses edge detection and contour analysis to find better click targets. | ||
""" | ||
x, y = target_coords | ||
region_size = self.dot_radius * 4 | ||
|
||
# Extract region of interest | ||
x1 = max(0, x - region_size) | ||
y1 = max(0, y - region_size) | ||
x2 = min(screenshot.shape[1], x + region_size) | ||
y2 = min(screenshot.shape[0], y + region_size) | ||
roi = screenshot[y1:y2, x1:x2] | ||
|
||
if roi.size == 0: | ||
return None | ||
|
||
# Convert to grayscale | ||
gray = cv2.cvtColor(roi, cv2.COLOR_BGR2GRAY) | ||
|
||
# Edge detection | ||
edges = cv2.Canny(gray, 50, 150) | ||
|
||
# Find contours | ||
contours, _ = cv2.findContours(edges, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE) | ||
|
||
if not contours: | ||
return None | ||
|
||
# Find the largest contour near the center | ||
center_x, center_y = region_size, region_size | ||
best_contour = None | ||
min_dist = float('inf') | ||
|
||
for contour in contours: | ||
M = cv2.moments(contour) | ||
if M["m00"] == 0: | ||
continue | ||
|
||
cx = int(M["m10"] / M["m00"]) | ||
cy = int(M["m01"] / M["m00"]) | ||
|
||
dist = np.sqrt((cx - center_x)**2 + (cy - center_y)**2) | ||
if dist < min_dist: | ||
min_dist = dist | ||
best_contour = contour | ||
|
||
if best_contour is None: | ||
return None | ||
|
||
# Get the center of the best contour | ||
M = cv2.moments(best_contour) | ||
cx = int(M["m10"] / M["m00"]) + x1 | ||
cy = int(M["m01"] / M["m00"]) + y1 | ||
|
||
return (cx, cy) | ||
|
||
def process_target(self, screenshot: np.ndarray, target_coords: Tuple[int, int]) -> Tuple[int, int]: | ||
"""Process the target coordinates with optional self-correction. | ||
Args: | ||
screenshot: The current screenshot | ||
target_coords: The original target coordinates | ||
Returns: | ||
Tuple[int, int]: The processed target coordinates | ||
""" | ||
# Ensure coordinates are within bounds | ||
x, y = target_coords | ||
height, width = screenshot.shape[:2] | ||
x = min(max(0, x), width - 1) | ||
y = min(max(0, y), height - 1) | ||
|
||
if not self.enable_self_correction: | ||
return (x, y) | ||
|
||
# Analyze the target region for better positioning | ||
corrected_coords = self._analyze_target_region(screenshot, (x, y)) | ||
|
||
if corrected_coords is not None: | ||
# Store the correction for visualization | ||
self.last_correction = { | ||
'original': (x, y), | ||
'corrected': corrected_coords | ||
} | ||
return corrected_coords | ||
|
||
return (x, y) | ||
|
||
def visualize_target(self, screenshot: np.ndarray, target_coords: Tuple[int, int]) -> np.ndarray: | ||
"""Visualize the target with a red dot and optional correction path. | ||
Args: | ||
screenshot: The current screenshot | ||
target_coords: The target coordinates | ||
Returns: | ||
np.ndarray: The screenshot with visualization overlay | ||
""" | ||
visualization = screenshot.copy() | ||
|
||
# Draw correction path if available | ||
if self.enable_self_correction and self.last_correction is not None: | ||
orig = self.last_correction['original'] | ||
corr = self.last_correction['corrected'] | ||
|
||
# Draw path from original to corrected position | ||
cv2.line(visualization, orig, corr, (0, 255, 0), 1) | ||
# Draw original position (smaller, yellow dot) | ||
cv2.circle(visualization, orig, self.dot_radius-2, (0, 255, 255), -1) | ||
|
||
# Draw target dot with white outline for visibility | ||
cv2.circle(visualization, target_coords, self.dot_radius+2, (255, 255, 255), -1) | ||
cv2.circle(visualization, target_coords, self.dot_radius, self.dot_color, -1) | ||
|
||
return visualization |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,53 @@ | ||
import pytest | ||
import numpy as np | ||
from replay_strategies.cursor_replay import CursorReplayStrategy | ||
|
||
@pytest.fixture | ||
def strategy(): | ||
return CursorReplayStrategy() | ||
|
||
@pytest.fixture | ||
def test_image(): | ||
return np.zeros((100, 100, 3), dtype=np.uint8) | ||
|
||
def test_initialization(): | ||
strategy = CursorReplayStrategy( | ||
dot_radius=10, | ||
dot_color=(255, 0, 0), | ||
enable_self_correction=True | ||
) | ||
assert strategy.dot_radius == 10 | ||
assert strategy.dot_color == (255, 0, 0) | ||
assert strategy.enable_self_correction is True | ||
|
||
def test_process_target_without_correction(strategy, test_image): | ||
coords = (50, 50) | ||
processed_coords = strategy.process_target(test_image, coords) | ||
assert processed_coords == coords | ||
|
||
def test_process_target_bounds_checking(strategy, test_image): | ||
# Test coordinates outside image bounds | ||
coords = (150, 150) | ||
processed_coords = strategy.process_target(test_image, coords) | ||
assert processed_coords[0] < 100 | ||
assert processed_coords[1] < 100 | ||
|
||
def test_visualize_target(strategy, test_image): | ||
coords = (50, 50) | ||
visualization = strategy.visualize_target(test_image, coords) | ||
|
||
# Check that visualization has changed the image | ||
assert not np.array_equal(visualization, test_image) | ||
|
||
# Check that the dot was drawn (should have non-zero pixels) | ||
assert np.sum(visualization) > 0 | ||
|
||
def test_self_correction_enabled(): | ||
strategy = CursorReplayStrategy(enable_self_correction=True) | ||
test_image = np.zeros((100, 100, 3), dtype=np.uint8) | ||
coords = (50, 50) | ||
|
||
processed_coords = strategy.process_target(test_image, coords) | ||
assert isinstance(processed_coords, tuple) | ||
assert len(processed_coords) == 2 | ||
assert all(isinstance(x, (int, np.integer)) for x in processed_coords) |