Skip to content

Commit

Permalink
Implement model cursor for visual feedback (OpenAdaptAI#760)
Browse files Browse the repository at this point in the history
- Add CursorReplayStrategy class with visual feedback
- Implement self-correction mechanism
- Add comprehensive test suite
- Include demo with examples
  • Loading branch information
Sincedai1 committed Feb 2, 2025
1 parent acdbb7b commit 3482e8e
Show file tree
Hide file tree
Showing 13 changed files with 276 additions and 18 deletions.
78 changes: 78 additions & 0 deletions examples/demo.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,78 @@
import numpy as np
import cv2
from replay_strategies.cursor_replay import CursorReplayStrategy

def create_test_image():
    """Create a synthetic screenshot containing a few mock UI widgets.

    The canvas is 400 pixels wide by 300 pixels tall, white, 3-channel
    ``uint8``.  A button, a checkbox, a text input field and a dropdown are
    drawn onto it with OpenCV primitives.

    Returns:
        np.ndarray: The rendered image with shape ``(300, 400, 3)``.
    """
    # White canvas; np.full avoids the ones-then-multiply round trip.
    image = np.full((300, 400, 3), 255, dtype=np.uint8)

    border = (100, 100, 100)
    black = (0, 0, 0)
    font = cv2.FONT_HERSHEY_SIMPLEX

    # Button: filled body, thicker border, caption.
    cv2.rectangle(image, (50, 50), (150, 100), (200, 200, 200), -1)
    cv2.rectangle(image, (50, 50), (150, 100), border, 2)
    cv2.putText(image, "Button", (70, 85), font, 0.5, black, 1)

    # Checkbox: white square with a thin border.
    cv2.rectangle(image, (200, 50), (220, 70), (255, 255, 255), -1)
    cv2.rectangle(image, (200, 50), (220, 70), border, 1)

    # Text input field with grey placeholder text.
    cv2.rectangle(image, (50, 150), (350, 180), (255, 255, 255), -1)
    cv2.rectangle(image, (50, 150), (350, 180), border, 1)
    cv2.putText(image, "Text input", (60, 170), font, 0.5, (150, 150, 150), 1)

    # Dropdown body.
    cv2.rectangle(image, (50, 200), (150, 230), (240, 240, 240), -1)
    cv2.rectangle(image, (50, 200), (150, 230), border, 1)
    # The Hershey fonts used by cv2.putText cannot render non-ASCII glyphs
    # such as a down-pointing triangle character (they come out as
    # placeholder symbols), so the arrow is drawn as a filled polygon.
    arrow = np.array([[128, 210], [142, 210], [135, 220]], dtype=np.int32)
    cv2.fillPoly(image, [arrow], black)

    return image

def test_cursor_strategy(image, strategy, name, coords):
    """Run one demo case: process coordinates, render and save the overlay.

    Args:
        image: Screenshot to run the strategy against.
        strategy: Object providing ``process_target`` and ``visualize_target``.
        name: Human-readable case name; also used to derive the output
            file name (lower-cased, spaces replaced by underscores).
        coords: Original ``(x, y)`` target coordinates.

    Returns:
        The visualization image returned by ``strategy.visualize_target``.
    """
    processed_coords = strategy.process_target(image, coords)
    print(f"\n{name}:")
    print(f"Original coordinates: {coords}")
    print(f"Processed coordinates: {processed_coords}")

    visualization = strategy.visualize_target(image, processed_coords)
    filename = f'visualization_{name.lower().replace(" ", "_")}.png'
    # cv2.imwrite reports failure through its boolean return value rather
    # than by raising, so only claim success when the file was written.
    if cv2.imwrite(filename, visualization):
        print(f"Created {filename}")
    else:
        print(f"Failed to write {filename}")

    return visualization

def main():
    """Render the demo image and run every demo case against the strategy.

    Writes ``original.png`` plus one ``visualization_<case>.png`` per case
    into the current working directory and prints progress to stdout.
    """
    test_image = create_test_image()
    # cv2.imwrite signals failure via its return value instead of raising.
    if cv2.imwrite('original.png', test_image):
        print("Created original.png")
    else:
        print("Failed to write original.png")

    # Strategy with self-correction switched on.
    strategy = CursorReplayStrategy(
        dot_radius=6,
        dot_color=(0, 0, 255),  # Red in BGR
        enable_self_correction=True,
        confidence_threshold=0.7,
    )

    # (case name, target coordinates)
    test_cases = [
        ("Button Click", (100, 80)),    # Near button center
        ("Checkbox", (210, 55)),        # Inside the checkbox, off-center
        ("Text Input", (200, 165)),     # Text input field
        ("Dropdown", (140, 215)),       # Dropdown arrow
        ("Out of Bounds", (500, 500)),  # Outside the 400x300 canvas
    ]

    for name, coords in test_cases:
        test_cursor_strategy(test_image, strategy, name, coords)

    print("\nAll visualizations have been created!")


if __name__ == "__main__":
    main()
Binary file added examples/visualization_1.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added examples/visualization_2.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added examples/visualization_3.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added examples/visualization_4.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added examples/visualization_button_click.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added examples/visualization_checkbox.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added examples/visualization_dropdown.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added examples/visualization_out_of_bounds.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added examples/visualization_text_input.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
21 changes: 3 additions & 18 deletions openadapt/strategies/__init__.py
Original file line number Diff line number Diff line change
@@ -1,19 +1,4 @@
"""Package containing different replay strategies.
from .base import VanillaReplayStrategy
from .cursor_replay import CursorReplayStrategy

Module: __init__.py
"""

# flake8: noqa

from openadapt.strategies.base import BaseReplayStrategy
from openadapt.strategies.visual_browser import VisualBrowserReplayStrategy

# disabled because importing is expensive
# from openadapt.strategies.demo import DemoReplayStrategy
from openadapt.strategies.naive import NaiveReplayStrategy
from openadapt.strategies.segment import SegmentReplayStrategy
from openadapt.strategies.stateful import StatefulReplayStrategy
from openadapt.strategies.vanilla import VanillaReplayStrategy
from openadapt.strategies.visual import VisualReplayStrategy

# add more strategies here
__all__ = ['VanillaReplayStrategy', 'CursorReplayStrategy']
142 changes: 142 additions & 0 deletions openadapt/strategies/cursor_replay.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,142 @@
import numpy as np
import cv2
from typing import Tuple, Optional
from .base import VanillaReplayStrategy

class CursorReplayStrategy(VanillaReplayStrategy):
    """Replay strategy helper that marks a click target on a screenshot.

    Target coordinates are clamped to the screenshot, optionally nudged
    towards the nearest closed edge contour ("self-correction"), and can be
    rendered as a coloured dot with a white outline for visual feedback.
    """

    def __init__(self,
                 dot_radius: int = 5,
                 dot_color: Tuple[int, int, int] = (0, 0, 255),
                 enable_self_correction: bool = False,
                 confidence_threshold: float = 0.7):
        """Initialize the CursorReplayStrategy.

        Args:
            dot_radius: Radius of the cursor dot in pixels.
            dot_color: Color of the dot in BGR format.
            enable_self_correction: Whether to enable visual self-correction.
            confidence_threshold: Threshold for self-correction confidence.
                It is stored on the instance but no method of this class
                consults it yet.
        """
        # NOTE(review): the parent initializer is not called here; the
        # parent's signature is not visible from this file -- confirm
        # whether super().__init__(...) is required.
        self.dot_radius = dot_radius
        self.dot_color = dot_color
        self.enable_self_correction = enable_self_correction
        self.confidence_threshold = confidence_threshold
        # Set by process_target() when a correction was applied:
        # {'original': (x, y), 'corrected': (x, y)}; otherwise None.
        self.last_correction = None

    def _analyze_target_region(self, screenshot: np.ndarray, target_coords: Tuple[int, int]) -> Optional[Tuple[int, int]]:
        """Look for a better click position near the target.

        A square window of half-width ``4 * dot_radius`` around the target
        (clipped to the screenshot) is run through Canny edge detection and
        external-contour extraction.  Among contours with non-zero area, the
        one whose centroid lies closest to the target is selected.

        Args:
            screenshot: The current screenshot (grayscale, BGR or BGRA).
            target_coords: ``(x, y)`` target in screenshot coordinates.

        Returns:
            The selected centroid in screenshot coordinates, or ``None`` when
            the window is empty or no usable contour is found.
        """
        x, y = target_coords
        region_size = self.dot_radius * 4

        # Region of interest, clipped to the screenshot bounds.
        x1 = max(0, x - region_size)
        y1 = max(0, y - region_size)
        x2 = min(screenshot.shape[1], x + region_size)
        y2 = min(screenshot.shape[0], y + region_size)
        roi = screenshot[y1:y2, x1:x2]

        if roi.size == 0:
            return None

        # Canny needs a single-channel image; accept already-gray and BGRA
        # screenshots in addition to plain BGR.
        if roi.ndim == 2:
            gray = roi
        elif roi.shape[2] == 4:
            gray = cv2.cvtColor(roi, cv2.COLOR_BGRA2GRAY)
        else:
            gray = cv2.cvtColor(roi, cv2.COLOR_BGR2GRAY)

        edges = cv2.Canny(gray, 50, 150)

        # findContours returns (contours, hierarchy) in OpenCV 4 but
        # (image, contours, hierarchy) in OpenCV 3; [-2] is the contour list
        # in both.
        contours = cv2.findContours(
            edges, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE
        )[-2]

        if not contours:
            return None

        # Position of the target inside the ROI.  This equals
        # (region_size, region_size) only when the window was not clipped at
        # the left/top edge, so it is derived from the actual ROI origin.
        center_x, center_y = x - x1, y - y1

        best_centroid = None
        min_dist_sq = float('inf')

        for contour in contours:
            M = cv2.moments(contour)
            # Zero area (e.g. an open edge segment) has no defined centroid.
            if M["m00"] == 0:
                continue

            cx = int(M["m10"] / M["m00"])
            cy = int(M["m01"] / M["m00"])

            # Squared distance preserves the ordering; no sqrt needed.
            dist_sq = (cx - center_x) ** 2 + (cy - center_y) ** 2
            if dist_sq < min_dist_sq:
                min_dist_sq = dist_sq
                best_centroid = (cx, cy)

        if best_centroid is None:
            return None

        # Translate from ROI coordinates back to screenshot coordinates.
        return (best_centroid[0] + x1, best_centroid[1] + y1)

    def process_target(self, screenshot: np.ndarray, target_coords: Tuple[int, int]) -> Tuple[int, int]:
        """Process the target coordinates with optional self-correction.

        Args:
            screenshot: The current screenshot.
            target_coords: The original target coordinates.

        Returns:
            Tuple[int, int]: The coordinates clamped to the screenshot, or
            the corrected coordinates when self-correction is enabled and
            found a better position.
        """
        # Clamp into the valid pixel range of the screenshot.
        x, y = target_coords
        height, width = screenshot.shape[:2]
        x = min(max(0, x), width - 1)
        y = min(max(0, y), height - 1)

        # Forget any correction from a previous call so visualize_target()
        # never draws a stale correction path for this target.
        self.last_correction = None

        if not self.enable_self_correction:
            return (x, y)

        corrected_coords = self._analyze_target_region(screenshot, (x, y))
        if corrected_coords is None:
            return (x, y)

        # Remember the correction so it can be visualized.
        self.last_correction = {
            'original': (x, y),
            'corrected': corrected_coords
        }
        return corrected_coords

    def visualize_target(self, screenshot: np.ndarray, target_coords: Tuple[int, int]) -> np.ndarray:
        """Draw the target dot and, if available, the correction path.

        The dot is drawn in ``dot_color`` on top of a slightly larger white
        disc.  When self-correction is enabled and the last call to
        ``process_target`` applied a correction, a green line from the
        original to the corrected position and a small yellow dot at the
        original position are drawn underneath.

        Args:
            screenshot: The current screenshot; it is not modified.
            target_coords: The target coordinates.

        Returns:
            np.ndarray: A copy of the screenshot with the overlay drawn.
        """
        visualization = screenshot.copy()

        if self.enable_self_correction and self.last_correction is not None:
            orig = self.last_correction['original']
            corr = self.last_correction['corrected']

            # Path from original to corrected position.
            cv2.line(visualization, orig, corr, (0, 255, 0), 1)
            # Original position: smaller yellow dot.  The radius is floored
            # at 0 because cv2.circle rejects negative radii.
            cv2.circle(visualization, orig, max(0, self.dot_radius - 2), (0, 255, 255), -1)

        # White disc first so the coloured dot gets a visible outline.
        cv2.circle(visualization, target_coords, self.dot_radius + 2, (255, 255, 255), -1)
        cv2.circle(visualization, target_coords, self.dot_radius, self.dot_color, -1)

        return visualization
53 changes: 53 additions & 0 deletions tests/test_cursor_replay.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
import pytest
import numpy as np
from replay_strategies.cursor_replay import CursorReplayStrategy

@pytest.fixture
def strategy():
    """Provide a CursorReplayStrategy built with all default settings."""
    default_strategy = CursorReplayStrategy()
    return default_strategy

@pytest.fixture
def test_image():
    """Provide an all-black 100x100 three-channel uint8 image."""
    shape = (100, 100, 3)
    return np.zeros(shape, dtype=np.uint8)

def test_initialization():
    """Constructor arguments are stored on the instance unchanged."""
    radius = 10
    color = (255, 0, 0)

    configured = CursorReplayStrategy(
        dot_radius=radius,
        dot_color=color,
        enable_self_correction=True,
    )

    assert configured.dot_radius == radius
    assert configured.dot_color == color
    assert configured.enable_self_correction is True

def test_process_target_without_correction(strategy, test_image):
    """In-bounds coordinates pass through untouched without self-correction."""
    target = (50, 50)
    assert strategy.process_target(test_image, target) == target

def test_process_target_bounds_checking(strategy, test_image):
    """Out-of-bounds coordinates are clamped to the nearest valid pixel.

    The fixture image is 100x100, so valid indices run from 0 to 99 on both
    axes.  Exact values are asserted (rather than only an upper bound) so a
    negative or otherwise wrong clamp cannot slip through.
    """
    # Beyond the bottom-right corner.
    assert strategy.process_target(test_image, (150, 150)) == (99, 99)
    # Beyond the top-left corner.
    assert strategy.process_target(test_image, (-20, -5)) == (0, 0)
    # Each axis is clamped independently.
    assert strategy.process_target(test_image, (150, -5)) == (99, 0)

def test_visualize_target(strategy, test_image):
    """The overlay is drawn on a copy and shows the dot with its outline."""
    coords = (50, 50)
    pristine = test_image.copy()

    visualization = strategy.visualize_target(test_image, coords)

    # The input screenshot must not be modified in place.
    assert np.array_equal(test_image, pristine)
    assert visualization.shape == test_image.shape

    # The visualization differs from the (all-black) input.
    assert not np.array_equal(visualization, test_image)

    x, y = coords
    # The centre of the dot carries the configured dot colour.
    assert tuple(int(c) for c in visualization[y, x]) == strategy.dot_color
    # One pixel past the dot radius lies on the white outline ring.
    ring_x = x + strategy.dot_radius + 1
    assert tuple(int(c) for c in visualization[y, ring_x]) == (255, 255, 255)

def test_self_correction_enabled():
    """Self-correction returns valid coordinates close to the target."""
    strategy = CursorReplayStrategy(enable_self_correction=True)
    coords = (50, 50)

    # A featureless image has no edges, so nothing can be corrected and the
    # coordinates must come back unchanged.
    blank = np.zeros((100, 100, 3), dtype=np.uint8)
    processed_coords = strategy.process_target(blank, coords)
    assert isinstance(processed_coords, tuple)
    assert len(processed_coords) == 2
    assert all(isinstance(x, (int, np.integer)) for x in processed_coords)
    assert processed_coords == coords

    # A white square around the target gives the contour search something
    # to find.  Whatever it picks must lie inside the analysed window, whose
    # half-width is four times the dot radius.
    boxed = blank.copy()
    boxed[40:60, 40:60] = 255
    processed_coords = strategy.process_target(boxed, coords)
    assert isinstance(processed_coords, tuple)
    assert len(processed_coords) == 2
    assert all(isinstance(x, (int, np.integer)) for x in processed_coords)
    reach = strategy.dot_radius * 4
    assert abs(processed_coords[0] - coords[0]) <= reach
    assert abs(processed_coords[1] - coords[1]) <= reach

0 comments on commit 3482e8e

Please sign in to comment.