From 1fa30949f74fa3d52e800854892e8d1c8e89504a Mon Sep 17 00:00:00 2001 From: Alexis Deprez Date: Wed, 2 Oct 2024 21:41:42 +0200 Subject: [PATCH] simplify driver implementation --- .../lavague/drivers/playwright/base.py | 76 +- .../lavague/drivers/selenium/__init__.py | 1 - .../lavague/drivers/selenium/base.py | 848 +++++------------- .../lavague/drivers/selenium/node.py | 106 +++ lavague-sdk/lavague/sdk/base_driver.py | 748 --------------- .../lavague/sdk/base_driver/__init__.py | 1 + lavague-sdk/lavague/sdk/base_driver/base.py | 277 ++++++ .../lavague/sdk/base_driver/interaction.py | 59 ++ .../lavague/sdk/base_driver/javascript.py | 319 +++++++ lavague-sdk/lavague/sdk/base_driver/node.py | 57 ++ lavague-sdk/lavague/sdk/client.py | 41 +- lavague-sdk/lavague/sdk/exceptions.py | 5 + lavague-sdk/lavague/sdk/trajectory/base.py | 7 +- lavague-sdk/lavague/sdk/trajectory/model.py | 1 + .../lavague/sdk/utilities/version_checker.py | 2 + 15 files changed, 1101 insertions(+), 1447 deletions(-) create mode 100644 lavague-integrations/drivers/lavague-drivers-selenium/lavague/drivers/selenium/node.py delete mode 100644 lavague-sdk/lavague/sdk/base_driver.py create mode 100644 lavague-sdk/lavague/sdk/base_driver/__init__.py create mode 100644 lavague-sdk/lavague/sdk/base_driver/base.py create mode 100644 lavague-sdk/lavague/sdk/base_driver/interaction.py create mode 100644 lavague-sdk/lavague/sdk/base_driver/javascript.py create mode 100644 lavague-sdk/lavague/sdk/base_driver/node.py diff --git a/lavague-integrations/drivers/lavague-drivers-playwright/lavague/drivers/playwright/base.py b/lavague-integrations/drivers/lavague-drivers-playwright/lavague/drivers/playwright/base.py index 78bbf337..dc6db7a8 100644 --- a/lavague-integrations/drivers/lavague-drivers-playwright/lavague/drivers/playwright/base.py +++ b/lavague-integrations/drivers/lavague-drivers-playwright/lavague/drivers/playwright/base.py @@ -4,12 +4,15 @@ from PIL import Image from typing import Callable, Optional, Any, Mapping, Dict, List from playwright.sync_api import Page, Locator -from lavague.sdk.base_driver import ( - BaseDriver, +from lavague.sdk.base_driver import BaseDriver +from lavague.sdk.base_driver.interaction import ( + InteractionType, + PossibleInteractionsByXpath, +) + +from lavague.sdk.base_driver.javascript import ( JS_GET_INTERACTIVES, JS_WAIT_DOM_IDLE, - PossibleInteractionsByXpath, - InteractionType, ) from lavague.sdk.exceptions import ( NoElementException, @@ -46,7 +49,7 @@ def __init__( # Before modifying this function, check if your changes are compatible with code_for_init which parses this code # these imports are necessary as they will be pasted to the output def default_init_code(self) -> Page: - from lavague.sdk.base_driver import JS_SETUP_GET_EVENTS + from lavague.sdk.base_driver.javascript import JS_SETUP_GET_EVENTS try: from playwright.sync_api import sync_playwright @@ -118,72 +121,9 @@ def get_html(self) -> str: def destroy(self) -> None: self.page.close() - def check_visibility(self, xpath: str) -> bool: - try: - locator = self.page.locator(f"xpath={xpath}") - return locator.is_visible() and locator.is_enabled() - except: - return False - def resolve_xpath(self, xpath) -> Locator: return self.page.locator(f"xpath={xpath}") - def get_highlighted_element(self, generated_code: str): - elements = [] - - data = json.loads(generated_code) - if not isinstance(data, List): - data = [data] - for item in data: - action_name = item["action"]["name"] - if action_name != "fail": - xpath = item["action"]["args"]["xpath"] - try: - elem = self.page.locator(f"xpath={xpath}") - elements.append(elem) - except: - pass - - if len(elements) == 0: - raise ValueError("No element found.") - - outputs = [] - for element in elements: - element: Locator - - bounding_box = {} - viewport_size = {} - - self.execute_script( - "arguments[0].setAttribute('style', arguments[1]);", - element, - "border: 2px solid red;", - ) - self.execute_script( - "arguments[0].scrollIntoView({block: 'center'});", element - ) - screenshot = self.get_screenshot_as_png() - - bounding_box["x1"] = element.bounding_box()["x"] - bounding_box["y1"] = element.bounding_box()["y"] - bounding_box["x2"] = bounding_box["x1"] + element.bounding_box()["width"] - bounding_box["y2"] = bounding_box["y1"] + element.bounding_box()["height"] - - viewport_size["width"] = self.execute_script("return window.innerWidth;") - viewport_size["height"] = self.execute_script("return window.innerHeight;") - screenshot = BytesIO(screenshot) - screenshot = Image.open(screenshot) - output = { - "screenshot": screenshot, - "bounding_box": bounding_box, - "viewport_size": viewport_size, - } - outputs.append(output) - return outputs - - def maximize_window(self) -> None: - pass - def exec_code( self, code: str, diff --git a/lavague-integrations/drivers/lavague-drivers-selenium/lavague/drivers/selenium/__init__.py b/lavague-integrations/drivers/lavague-drivers-selenium/lavague/drivers/selenium/__init__.py index 1c340f2e..a8e41bbe 100644 --- a/lavague-integrations/drivers/lavague-drivers-selenium/lavague/drivers/selenium/__init__.py +++ b/lavague-integrations/drivers/lavague-drivers-selenium/lavague/drivers/selenium/__init__.py @@ -1,2 +1 @@ from lavague.drivers.selenium.base import SeleniumDriver -from lavague.drivers.selenium.base import BrowserbaseRemoteConnection diff --git a/lavague-integrations/drivers/lavague-drivers-selenium/lavague/drivers/selenium/base.py b/lavague-integrations/drivers/lavague-drivers-selenium/lavague/drivers/selenium/base.py index bc14ddb1..bf3beb13 100644 --- a/lavague-integrations/drivers/lavague-drivers-selenium/lavague/drivers/selenium/base.py +++ b/lavague-integrations/drivers/lavague-drivers-selenium/lavague/drivers/selenium/base.py @@ -1,383 +1,233 @@ -import re -from typing import Any, Optional, Callable, Mapping, Dict, List -from selenium.webdriver.remote.webdriver import WebDriver -from selenium.webdriver.remote.shadowroot import ShadowRoot -from selenium.webdriver.common.by import By -from selenium.webdriver.common.keys import Keys -from selenium.common.exceptions import ( - NoSuchElementException, - WebDriverException, - ElementClickInterceptedException, - TimeoutException, +import json +import time +from typing import Callable, Dict, List, Optional + +from lavague.drivers.selenium.node import SeleniumNode +from lavague.sdk.action.navigation import NavigationOutput +from lavague.sdk.base_driver import BaseDriver +from lavague.sdk.base_driver.interaction import ( + InteractionType, + PossibleInteractionsByXpath, + ScrollDirection, ) -from selenium.webdriver.support.ui import Select, WebDriverWait -from selenium.webdriver.remote.webelement import WebElement -from selenium.webdriver.common.actions.wheel_input import ScrollOrigin -from lavague.sdk.base_driver import ( - BaseDriver, +from lavague.sdk.base_driver.javascript import ( + ATTACH_MOVE_LISTENER, JS_GET_INTERACTIVES, - JS_WAIT_DOM_IDLE, JS_GET_SCROLLABLE_PARENT, JS_GET_SHADOW_ROOTS, - PossibleInteractionsByXpath, - ScrollDirection, - InteractionType, - DOMNode, + JS_SETUP_GET_EVENTS, + JS_WAIT_DOM_IDLE, + REMOVE_HIGHLIGHT, + get_highlighter_style, ) from lavague.sdk.exceptions import ( CannotBackException, - NoElementException, - AmbiguousException, + NoPageException, +) + +from selenium.common.exceptions import ( + NoSuchElementException, + TimeoutException, ) -from PIL import Image -from io import BytesIO +from selenium.webdriver import Chrome from selenium.webdriver.chrome.options import Options from selenium.webdriver.common.action_chains import ActionChains -import time -import yaml -import json -from selenium.webdriver.remote.remote_connection import RemoteConnection -import requests -import os -from lavague.drivers.selenium.javascript import ( - ATTACH_MOVE_LISTENER, - get_highlighter_style, - REMOVE_HIGHLIGHT, -) +from selenium.webdriver.common.actions.wheel_input import ScrollOrigin +from selenium.webdriver.common.by import By +from selenium.webdriver.remote.webdriver import WebDriver +from selenium.webdriver.remote.webelement import WebElement +from selenium.webdriver.support.ui import Select, WebDriverWait -class SeleniumDriver(BaseDriver): +class SeleniumDriver(BaseDriver[SeleniumNode]): driver: WebDriver - last_hover_xpath: Optional[str] = None def __init__( self, - url: Optional[str] = None, - get_selenium_driver: Optional[Callable[[], WebDriver]] = None, + options: Optional[Options] = None, headless: bool = True, user_data_dir: Optional[str] = None, - width: Optional[int] = 1096, - height: Optional[int] = 1096, - options: Optional[Options] = None, - driver: Optional[WebDriver] = None, - log_waiting_time=False, waiting_completion_timeout=10, - remote_connection: Optional["BrowserbaseRemoteConnection"] = None, - ): - self.headless = headless - self.user_data_dir = user_data_dir - self.width = width - self.height = height - self.options = options - self.driver = driver - self.log_waiting_time = log_waiting_time + log_waiting_time=False, + user_agent="Mozilla/5.0 (Windows NT 10.0; WOW64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/107.0.0.0 Safari/537.36", + auto_init=True, + ) -> None: self.waiting_completion_timeout = waiting_completion_timeout - self.remote_connection = remote_connection - super().__init__(url, get_selenium_driver) - - # Default code to init the driver. - # Before making any change to this, make sure it is compatible with code_for_init, which parses the code of this function - # These imports are necessary as they will be pasted to the output - def default_init_code(self) -> Any: - from selenium import webdriver - from selenium.webdriver.common.by import By - from selenium.webdriver.chrome.options import Options - from selenium.webdriver.common.keys import Keys - from selenium.webdriver.common.action_chains import ActionChains - from lavague.sdk.base_driver import JS_SETUP_GET_EVENTS - - if self.options: - chrome_options = self.options + self.log_waiting_time = log_waiting_time + if options: + self.options = options else: - chrome_options = Options() - if self.headless: - chrome_options.add_argument("--headless=new") - if self.user_data_dir: - chrome_options.add_argument(f"--user-data-dir={self.user_data_dir}") - else: - chrome_options.add_argument("--lang=en") - user_agent = "Mozilla/5.0 (Windows NT 10.0; WOW64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/107.0.0.0 Safari/537.36" - chrome_options.add_argument(f"user-agent={user_agent}") - chrome_options.add_argument("--no-sandbox") - chrome_options.page_load_strategy = "normal" - # allow access to cross origin iframes - chrome_options.add_argument("--disable-web-security") - chrome_options.add_argument("--disable-site-isolation-trials") - chrome_options.add_argument("--disable-notifications") - chrome_options.set_capability("goog:loggingPrefs", {"performance": "ALL"}) - - if self.remote_connection: - chrome_options.add_experimental_option("debuggerAddress", "localhost:9223") - self.driver = webdriver.Remote( - self.remote_connection, options=chrome_options - ) - elif self.driver is None: - self.driver = webdriver.Chrome(options=chrome_options) - - # 538: browserbase implementation - move execute_cdp_cmd to inner block to avoid error - # AttributeError: 'WebDriver' object has no attribute 'execute_cdp_cmd' - self.driver.execute_cdp_cmd( - "Page.addScriptToEvaluateOnNewDocument", - {"source": JS_SETUP_GET_EVENTS}, - ) - self.resize_driver(self.width, self.height) - return self.driver - - def __enter__(self): - return self + self.options = Options() + if headless: + self.options.add_argument("--headless=new") + self.options.add_argument("--lang=en") + self.options.add_argument(f"user-agent={user_agent}") + self.options.add_argument("--disable-notifications") + if user_data_dir: + self.options.add_argument(f"--user-data-dir={user_data_dir}") + self.options.page_load_strategy = "normal" + self.options.add_argument("--no-sandbox") + self.options.add_argument("--disable-web-security") + self.options.add_argument("--disable-site-isolation-trials") + self.options.set_capability("goog:loggingPrefs", {"performance": "ALL"}) + if auto_init: + self.init() + + def init(self) -> None: + self.driver = Chrome(options=self.options) + self.driver.execute_cdp_cmd( + "Page.addScriptToEvaluateOnNewDocument", + {"source": JS_SETUP_GET_EVENTS}, + ) - def __exit__(self, *args): - self.destroy() + def execute(self, action: NavigationOutput) -> None: + """Execute an action""" + with self.resolve_xpath(action.xpath) as node: + match action.navigation_command: + case InteractionType.CLICK: + node.element.click() + + case InteractionType.TYPE: + value = action.value or "" + if node.element.tag_name == "input": + node.element.clear() + if node.element.tag_name == "select": + select = Select(node.element) + try: + select.select_by_value(value) + except NoSuchElementException: + select.select_by_visible_text(value) + else: + node.element.send_keys(value) + + case InteractionType.HOVER: + ActionChains(self.driver).move_to_element(node.element).perform() + + case InteractionType.SCROLL: + direction = ScrollDirection.from_string(action.value or "DOWN") + self.scroll(action.xpath, direction) - def get_driver(self) -> WebDriver: - return self.driver + def destroy(self) -> None: + """Cleanly destroy the underlying driver""" + self.driver.quit() - def resize_driver(self, width, height) -> None: - if width is None and height is None: - return None - # Selenium is only being able to set window size and not viewport size + def resize_driver(self, width: int, height: int): + """Resize the viewport to a targeted height and width""" self.driver.set_window_size(width, height) viewport_height = self.driver.execute_script("return window.innerHeight;") - height_difference = height - viewport_height self.driver.set_window_size(width, height + height_difference) - self.width = width - self.height = height - - def code_for_resize(self, width, height) -> str: - return f""" -driver.set_window_size({width}, {height}) -viewport_height = driver.execute_script("return window.innerHeight;") -height_difference = {height} - viewport_height -driver.set_window_size({width}, {height} + height_difference) -""" - def get_url(self) -> Optional[str]: + def get_url(self) -> str: + """Get the url of the current page, raise NoPageException if no page is loaded""" if self.driver.current_url == "data:,": - return None + raise NoPageException() return self.driver.current_url - def code_for_get(self, url: str) -> str: - return f'driver.get("{url}")' - def get(self, url: str) -> None: + """Navigate to the url""" self.driver.get(url) def back(self) -> None: - if self.driver.execute_script("return !document.referrer"): + """Navigate back, raise CannotBackException if history root is reached""" + if self.driver.execute_script("return !document.referrer;"): raise CannotBackException() self.driver.back() - def code_for_back(self) -> None: - return "driver.back()" - def get_html(self) -> str: + """ + Returns the HTML of the current page. + If clean is True, We remove unnecessary tags and attributes from the HTML. + Clean HTMLs are easier to process for the LLM. + """ return self.driver.page_source - def get_screenshot_as_png(self) -> bytes: - return self.driver.get_screenshot_as_png() - - def destroy(self) -> None: - self.driver.quit() - - def maximize_window(self) -> None: - self.driver.maximize_window() + def get_tabs(self) -> str: + """Return description of the tabs opened with the current tab being focused. - def check_visibility(self, xpath: str) -> bool: - try: - # Done manually here to avoid issues - element = self.resolve_xpath(xpath).element - res = ( - element is not None and element.is_displayed() and element.is_enabled() - ) - self.switch_default_frame() - return res - except: - return False + Example of output: + Tabs opened: + 0 - Overview - OpenAI API + 1 - [CURRENT] Nos destinations Train - SNCF Connect + """ + window_handles = self.driver.window_handles + # Store the current window handle (focused tab) + current_handle = self.driver.current_window_handle + tab_info = [] + tab_id = 0 - def get_viewport_size(self) -> dict: - viewport_size = {} - viewport_size["width"] = self.execute_script("return window.innerWidth;") - viewport_size["height"] = self.execute_script("return window.innerHeight;") - return viewport_size + for handle in window_handles: + # Switch to each tab + self.driver.switch_to.window(handle) - def get_highlighted_element(self, generated_code: str): - elements = [] - - # Ensures that numeric values are quoted - generated_code = quote_numeric_yaml_values(generated_code) - - data = yaml.safe_load(generated_code) - if not isinstance(data, List): - data = [data] - for item in data: - for action in item["actions"]: - try: - xpath = action["action"]["args"]["xpath"] - elem = self.driver.find_element(By.XPATH, xpath) - elements.append(elem) - except: - pass - - outputs = [] - for element in elements: - element: WebElement - - bounding_box = {} - - self.execute_script( - "arguments[0].setAttribute('style', arguments[1]);", - element, - "border: 2px solid red;", - ) - self.execute_script( - "arguments[0].scrollIntoView({block: 'center'});", element - ) - screenshot = self.get_screenshot_as_png() - - bounding_box["x1"] = element.location["x"] - bounding_box["y1"] = element.location["y"] - bounding_box["x2"] = bounding_box["x1"] + element.size["width"] - bounding_box["y2"] = bounding_box["y1"] + element.size["height"] - - screenshot = BytesIO(screenshot) - screenshot = Image.open(screenshot) - output = { - "screenshot": screenshot, - "bounding_box": bounding_box, - "viewport_size": self.get_viewport_size(), - } - outputs.append(output) - return outputs - - def switch_frame(self, xpath): - iframe = self.driver.find_element(By.XPATH, xpath) - self.driver.switch_to.frame(iframe) + # Get the title of the current tab + title = self.driver.title - def switch_default_frame(self) -> None: - self.driver.switch_to.default_content() + # Check if this is the focused tab + if handle == current_handle: + tab_info.append(f"{tab_id} - [CURRENT] {title}") + else: + tab_info.append(f"{tab_id} - {title}") - def switch_parent_frame(self) -> None: - self.driver.switch_to.parent_frame() + tab_id += 1 - def resolve_xpath(self, xpath: Optional[str]) -> "SeleniumNode": - return SeleniumNode(xpath, self) + # Switch back to the original tab + self.driver.switch_to.window(current_handle) - def exec_code( - self, - code: str, - globals: dict[str, Any] = None, - locals: Mapping[str, object] = None, - ): - # Ensures that numeric values are quoted to avoid issues with YAML parsing - code = quote_numeric_yaml_values(code) - - data = yaml.safe_load(code) - if not isinstance(data, List): - data = [data] - for item in data: - for action in item["actions"]: - action_name = action["action"]["name"] - args = action["action"]["args"] - xpath = args.get("xpath", None) - - match action_name: - case "click": - self.click(xpath) - case "setValue": - self.set_value(xpath, args["value"]) - case "setValueAndEnter": - self.set_value(xpath, args["value"], True) - case "dropdownSelect": - self.dropdown_select(xpath, args["value"]) - case "hover": - self.hover(xpath) - case "scroll": - self.scroll( - xpath, - ScrollDirection.from_string(args.get("value", "DOWN")), - ) - case "failNoElement": - raise NoElementException("No element: " + args["value"]) - case "failAmbiguous": - raise AmbiguousException("Ambiguous: " + args["value"]) - case _: - raise ValueError(f"Unknown action: {action_name}") - - self.wait_for_idle() - - def execute_script(self, js_code: str, *args) -> Any: - return self.driver.execute_script(js_code, *args) - - def scroll_up(self): - self.scroll(direction=ScrollDirection.UP) - - def scroll_down(self): - self.scroll(direction=ScrollDirection.DOWN) - - def code_for_execute_script(self, js_code: str, *args) -> str: - return ( - f"driver.execute_script({js_code}, {', '.join(str(arg) for arg in args)})" - ) + tab_info = "\n".join(tab_info) + tab_info = "Tabs opened:\n" + tab_info + return tab_info - def hover(self, xpath: str): - with self.resolve_xpath(xpath) as element_resolved: - self.last_hover_xpath = xpath - ActionChains(self.driver).move_to_element( - element_resolved.element - ).perform() + def switch_tab(self, tab_id: int) -> None: + """Switch to the tab with the given id""" + window_handles = self.driver.window_handles + self.driver.switch_to.window(window_handles[tab_id]) - def scroll_page(self, direction: ScrollDirection = ScrollDirection.DOWN): - self.driver.execute_script(direction.get_page_script()) + def resolve_xpath(self, xpath: str): + """ + Return the element for the corresponding xpath, the underlying driver may switch iframe if necessary + """ + return SeleniumNode(self.driver, xpath) - def get_scroll_anchor(self, xpath_anchor: Optional[str] = None) -> WebElement: - with self.resolve_xpath( - xpath_anchor or self.last_hover_xpath - ) as element_resolved: - element = element_resolved.element - parent = self.driver.execute_script(JS_GET_SCROLLABLE_PARENT, element) - scroll_anchor = parent or element - return scroll_anchor + def get_viewport_size(self) -> dict: + """Return viewport size as {"width": int, "height": int}""" + viewport_size = {} + viewport_size["width"] = self.driver.execute_script("return window.innerWidth;") + viewport_size["height"] = self.driver.execute_script( + "return window.innerHeight;" + ) + return viewport_size - def get_scroll_container_size(self, scroll_anchor: WebElement): - container = self.driver.execute_script(JS_GET_SCROLLABLE_PARENT, scroll_anchor) - if container: - return ( - self.driver.execute_script( - "const r = arguments[0].getBoundingClientRect(); return [r.width, r.height]", - scroll_anchor, - ), - True, - ) - return ( - self.driver.execute_script( - "return [window.innerWidth, window.innerHeight]", - ), + def get_possible_interactions( + self, + in_viewport=True, + foreground_only=True, + types: List[InteractionType] = [ + InteractionType.CLICK, + InteractionType.TYPE, + InteractionType.HOVER, + ], + ) -> PossibleInteractionsByXpath: + """Get elements that can be interacted with as a dictionary mapped by xpath""" + exe: Dict[str, List[str]] = self.driver.execute_script( + JS_GET_INTERACTIVES, + in_viewport, + foreground_only, False, + [t.name for t in types], ) + res = dict() + for k, v in exe.items(): + res[k] = set(InteractionType[i] for i in v) + return res - def is_bottom_of_page(self) -> bool: - return not self.can_scroll(direction=ScrollDirection.DOWN) - - def can_scroll( - self, - xpath_anchor: Optional[str] = None, - direction: ScrollDirection = ScrollDirection.DOWN, - ) -> bool: - try: - scroll_anchor = self.get_scroll_anchor(xpath_anchor) - if scroll_anchor: - return self.driver.execute_script( - direction.get_script_element_is_scrollable(), - scroll_anchor, - ) - except NoSuchElementException: - pass - return self.driver.execute_script(direction.get_script_page_is_scrollable()) + def scroll_into_view(self, xpath: str): + with self.resolve_xpath(xpath) as node: + self.driver.execute_script("arguments[0].scrollIntoView()", node.element) def scroll( self, - xpath_anchor: Optional[str] = None, + xpath_anchor: Optional[str] = "/html/body", direction: ScrollDirection = ScrollDirection.DOWN, scroll_factor=0.75, ): @@ -398,91 +248,38 @@ def scroll( ActionChains(self.driver).scroll_by_amount( scroll_xy[0], scroll_xy[1] ).perform() - if xpath_anchor: - self.last_hover_xpath = xpath_anchor except NoSuchElementException: self.scroll_page(direction) - def click(self, xpath: str): - with self.resolve_xpath(xpath) as element_resolved: - element = element_resolved.element - self.last_hover_xpath = xpath - try: - element.click() - except ElementClickInterceptedException: - try: - # Move to the element and click at its position - ActionChains(self.driver).move_to_element(element).click().perform() - except WebDriverException as click_error: - raise Exception( - f"Failed to click at element coordinates of {xpath} : {str(click_error)}" - ) - except Exception as e: - import traceback - - traceback.print_exc() - raise Exception( - f"An unexpected error occurred when trying to click on {xpath}: {str(e)}" - ) - - def set_value(self, xpath: str, value: str, enter: bool = False): - with self.resolve_xpath(xpath) as element_resolved: - elem = element_resolved.element - try: - self.last_hover_xpath = xpath - if elem.tag_name == "select": - # use the dropdown_select to set the value of a select - return self.dropdown_select(xpath, value) - if elem.tag_name == "input" and elem.get_attribute("type") == "file": - # set the value of a file input - return self.upload_file(xpath, value) - - elem.clear() - except: - # might not be a clearable element, but global click + send keys can still success - pass - - self.click(xpath) - - ( - ActionChains(self.driver) - .key_down(Keys.CONTROL) - .send_keys("a") - .key_up(Keys.CONTROL) - .send_keys(Keys.DELETE) # clear the input field - .send_keys(value) - .perform() - ) - if enter: - ActionChains(self.driver).send_keys(Keys.ENTER).perform() + def scroll_page(self, direction: ScrollDirection = ScrollDirection.DOWN): + self.driver.execute_script(direction.get_page_script()) - def dropdown_select(self, xpath: str, value: str): - with self.resolve_xpath(xpath) as element_resolved: - element = element_resolved.element - self.last_hover_xpath = xpath - - if element.tag_name != "select": - print( - f"Cannot use dropdown_select on {element.tag_name}, falling back to simple click on {xpath}" - ) - return self.click(xpath) - - select = Select(element) - try: - select.select_by_value(value) - except NoSuchElementException: - select.select_by_visible_text(value) - - def upload_file(self, xpath: str, file_path: str): - with self.resolve_xpath(xpath) as element_resolved: + def get_scroll_anchor(self, xpath_anchor: Optional[str] = None) -> WebElement: + with self.resolve_xpath(xpath_anchor or "/html/body") as element_resolved: element = element_resolved.element - self.last_hover_xpath = xpath - element.send_keys(file_path) + parent = self.driver.execute_script(JS_GET_SCROLLABLE_PARENT, element) + scroll_anchor = parent or element + return scroll_anchor - def perform_wait(self, duration: float): - import time + def get_scroll_container_size(self, scroll_anchor: WebElement): + container = self.driver.execute_script(JS_GET_SCROLLABLE_PARENT, scroll_anchor) + if container: + return ( + self.driver.execute_script( + "const r = arguments[0].getBoundingClientRect(); return [r.width, r.height]", + scroll_anchor, + ), + True, + ) + return ( + self.driver.execute_script( + "return [window.innerWidth, window.innerHeight]", + ), + False, + ) - time.sleep(duration) + def wait_for_dom_stable(self, timeout: float = 10): + self.driver.execute_script(JS_WAIT_DOM_IDLE, max(0, round(timeout * 1000))) def is_idle(self): active = 0 @@ -510,9 +307,6 @@ def is_idle(self): return len(request_ids) == 0 and active <= 0 - def wait_for_dom_stable(self, timeout: float = 10): - self.driver.execute_script(JS_WAIT_DOM_IDLE, max(0, round(timeout * 1000))) - def wait_for_idle(self): t = time.time() elapsed = 0 @@ -532,49 +326,39 @@ def wait_for_idle(self): ) def get_capability(self) -> str: + """Prompt to explain the llm which style of code he should output and which variables and imports he should expect""" return SELENIUM_PROMPT_TEMPLATE - def get_tabs(self): - driver = self.driver - window_handles = driver.window_handles - # Store the current window handle (focused tab) - current_handle = driver.current_window_handle - tab_info = [] - tab_id = 0 - - for handle in window_handles: - # Switch to each tab - driver.switch_to.window(handle) - - # Get the title of the current tab - title = driver.title - - # Check if this is the focused tab - if handle == current_handle: - tab_info.append(f"{tab_id} - [CURRENT] {title}") - else: - tab_info.append(f"{tab_id} - {title}") - - tab_id += 1 - - # Switch back to the original tab - driver.switch_to.window(current_handle) + def get_screenshot_as_png(self) -> bytes: + return self.driver.get_screenshot_as_png() - tab_info = "\n".join(tab_info) - tab_info = "Tabs opened:\n" + tab_info - return tab_info + def get_shadow_roots(self) -> Dict[str, str]: + """Return a dictionary of shadow roots HTML by xpath""" + return self.driver.execute_script(JS_GET_SHADOW_ROOTS) - def switch_tab(self, tab_id: int): - driver = self.driver - window_handles = driver.window_handles + def get_nodes(self, xpaths: List[str]) -> List[SeleniumNode]: + return [SeleniumNode(self.driver, xpath) for xpath in xpaths] - # Switch to the tab with the given id - driver.switch_to.window(window_handles[tab_id]) + def highlight_nodes( + self, xpaths: List[str], color: str = "red", label=False + ) -> Callable: + nodes = self.get_nodes(xpaths) + self.driver.execute_script(ATTACH_MOVE_LISTENER) + set_style = get_highlighter_style(color, label) + self.exec_script_for_nodes( + nodes, "arguments[0].forEach((a, i) => { " + set_style + "})" + ) + return self._add_highlighted_destructors( + lambda: self.remove_nodes_highlight(xpaths) + ) - def get_nodes(self, xpaths: List[str]) -> List["SeleniumNode"]: - return [SeleniumNode(xpath, self) for xpath in xpaths] + def remove_nodes_highlight(self, xpaths: List[str]): + self.exec_script_for_nodes( + self.get_nodes(xpaths), + REMOVE_HIGHLIGHT, + ) - def exec_script_for_nodes(self, nodes: List["SeleniumNode"], script: str): + def exec_script_for_nodes(self, nodes: List[SeleniumNode], script: str): standard_nodes: List[SeleniumNode] = [] special_nodes: List[SeleniumNode] = [] @@ -602,190 +386,14 @@ def exec_script_for_nodes(self, nodes: List["SeleniumNode"], script: str): script, [n.element], ) - self.switch_default_frame() - - def remove_nodes_highlight(self, xpaths: List[str]): - self.exec_script_for_nodes( - self.get_nodes(xpaths), - REMOVE_HIGHLIGHT, - ) - - def highlight_nodes( - self, xpaths: List[str], color: str = "red", label=False - ) -> Callable: - nodes = self.get_nodes(xpaths) - self.driver.execute_script(ATTACH_MOVE_LISTENER) - set_style = get_highlighter_style(color, label) - self.exec_script_for_nodes( - nodes, "arguments[0].forEach((a, i) => { " + set_style + "})" - ) - return self._add_highlighted_destructors( - lambda: self.remove_nodes_highlight(xpaths) - ) + self.driver.switch_to.default_content() - def get_possible_interactions( - self, - in_viewport=True, - foreground_only=True, - types: List[InteractionType] = [ - InteractionType.CLICK, - InteractionType.TYPE, - InteractionType.HOVER, - ], - ) -> PossibleInteractionsByXpath: - exe: Dict[str, List[str]] = self.driver.execute_script( - JS_GET_INTERACTIVES, - in_viewport, - foreground_only, - False, - [t.name for t in types], - ) - res = dict() - for k, v in exe.items(): - res[k] = set(InteractionType[i] for i in v) - return res - - def get_in_viewport(self): - res: Dict[str, List[str]] = self.driver.execute_script( - JS_GET_INTERACTIVES, - True, - True, - True, - ) - return list(res.keys()) - - def get_shadow_roots(self) -> Dict[str, str]: - return self.driver.execute_script(JS_GET_SHADOW_ROOTS) - - -class SeleniumNode(DOMNode): - def __init__( - self, - xpath: Optional[str], - driver: SeleniumDriver, - element: Optional[WebElement] = None, - ) -> None: - if not xpath: - raise NoSuchElementException("xpath is missing") - self.xpath = xpath - self._driver = driver - if element: - self._element = element - super().__init__() - - @property - def element(self) -> Optional[WebElement]: - if not hasattr(self, "_element"): - print("WARN: DOMNode context manager missing") - self.__enter__() - return self._element - - @property - def value(self) -> Any: - elem = self.element - return elem.get_attribute("value") if elem else None - - def highlight(self, color: str = "red", bounding_box=True): - self._driver.highlight_nodes([self.xpath], color, bounding_box) - return self - - def clear(self): - self._driver.remove_nodes_highlight([self.xpath]) - return self - - def take_screenshot(self): - with self: - if self.element: - try: - return Image.open(BytesIO(self.element.screenshot_as_png)) - except WebDriverException: - pass - return Image.new("RGB", (0, 0)) - - def get_html(self): - with self: - return self._driver.driver.execute_script( - "return arguments[0].outerHTML", self.element - ) - - def __enter__(self): - if hasattr(self, "_element"): - return self - - self._element = None - if not self.xpath: - return self - - root = self._driver.driver - local_xpath = self.xpath - - def find_element(xpath): - try: - if isinstance(root, ShadowRoot): - # Shadow root does not support find_element with xpath - css_selector = re.sub( - r"\[([0-9]+)\]", - r":nth-of-type(\1)", - xpath[1:].replace("/", " > "), - ) - return root.find_element(By.CSS_SELECTOR, css_selector) - return root.find_element(By.XPATH, xpath) - except Exception: - return None - - while local_xpath: - match = re.search(r"/iframe|//", local_xpath) - - if match: - before, sep, local_xpath = local_xpath.partition(match.group()) - if sep == "/iframe": - self._driver.switch_frame(before + sep) - elif sep == "//": - custom_element = find_element(before) - if not custom_element: - break - root = custom_element.shadow_root - local_xpath = "/" + local_xpath - else: - break - - self._element = find_element(local_xpath) - return self - - def __exit__(self, exc_type, exc_val, exc_tb): - if hasattr(self, "_element"): - self._driver.switch_default_frame() - del self._element - - -class BrowserbaseRemoteConnection(RemoteConnection): - _session_id = None + def switch_frame(self, xpath: str) -> None: + iframe = self.driver.find_element(By.XPATH, xpath) + self.driver.switch_to.frame(iframe) - def __init__( - self, - remote_server_addr: str, - api_key: Optional[str] = None, - project_id: Optional[str] = None, - ): - super().__init__(remote_server_addr) - self.api_key = api_key or os.environ["BROWSERBASE_API_KEY"] - self.project_id = project_id or os.environ["BROWSERBASE_PROJECT_ID"] - - def get_remote_connection_headers(self, parsed_url, keep_alive=False): - if self._session_id is None: - self._session_id = self._create_session() - headers = super().get_remote_connection_headers(parsed_url, keep_alive) - headers.update({"x-bb-api-key": self.api_key}) - headers.update({"session-id": self._session_id}) - return headers - - def _create_session(self): - url = "https://www.browserbase.com/v1/sessions" - headers = {"Content-Type": "application/json", "x-bb-api-key": self.api_key} - response = requests.post( - url, json={"projectId": self.project_id}, headers=headers - ) - return response.json()["id"] + def switch_parent_frame(self) -> None: + self.driver.switch_to.parent_frame() SELENIUM_PROMPT_TEMPLATE = """ diff --git a/lavague-integrations/drivers/lavague-drivers-selenium/lavague/drivers/selenium/node.py b/lavague-integrations/drivers/lavague-drivers-selenium/lavague/drivers/selenium/node.py new file mode 100644 index 00000000..2f37140f --- /dev/null +++ b/lavague-integrations/drivers/lavague-drivers-selenium/lavague/drivers/selenium/node.py @@ -0,0 +1,106 @@ +import re +from io import BytesIO +from typing import Optional + +from lavague.sdk.base_driver import DOMNode +from lavague.sdk.exceptions import NoElementException +from PIL import Image + +from selenium.common.exceptions import WebDriverException +from selenium.webdriver.common.by import By +from selenium.webdriver.remote.shadowroot import ShadowRoot +from selenium.webdriver.remote.webdriver import WebDriver +from selenium.webdriver.remote.webelement import WebElement + + +class SeleniumNode(DOMNode[WebElement]): + def __init__( + self, + driver: WebDriver, + xpath: str, + element: Optional[WebElement] = None, + ) -> None: + self.driver = driver + self.xpath = xpath + if element: + self._element = element + super().__init__() + + @property + def element(self) -> WebElement: + if not hasattr(self, "_element"): + print("WARN: DOMNode context manager missing") + self.__enter__() + if self._element is None: + raise NoElementException() + return self._element + + @property + def value(self) -> Optional[str]: + return self.element.get_attribute("value") + + @property + def text(self) -> str: + return self.element.text + + @property + def outer_html(self) -> str: + return self.driver.execute_script("return arguments[0].outerHTML", self.element) + + @property + def inner_html(self) -> str: + return self.driver.execute_script("return arguments[0].innerHTML", self.element) + + def take_screenshot(self): + with self: + if self.element: + try: + return Image.open(BytesIO(self.element.screenshot_as_png)) + except WebDriverException: + pass + return Image.new("RGB", (0, 0)) + + def enter_context(self): + if hasattr(self, "_element"): + return + + root = self.driver + local_xpath = self.xpath + + def find_element(xpath): + try: + if isinstance(root, ShadowRoot): + # Shadow root does not support find_element with xpath + css_selector = re.sub( + r"\[([0-9]+)\]", + r":nth-of-type(\1)", + xpath[1:].replace("/", " > "), + ) + return root.find_element(By.CSS_SELECTOR, css_selector) + return root.find_element(By.XPATH, xpath) + except Exception: + return None + + while local_xpath: + match = re.search(r"/iframe|//", local_xpath) + + if match: + before, sep, local_xpath = local_xpath.partition(match.group()) + if sep == "/iframe": + iframe = self.driver.find_element(By.XPATH, before + sep) + self.driver.switch_to.frame(iframe) + elif sep == "//": + custom_element = find_element(before) + if not custom_element: + break + root = custom_element.shadow_root + local_xpath = "/" + local_xpath + else: + break + + self._element = find_element(local_xpath) + + def exit_context(self): + if hasattr(self, "_element"): + self.driver.switch_to.default_content() + del self._element diff --git a/lavague-sdk/lavague/sdk/base_driver.py b/lavague-sdk/lavague/sdk/base_driver.py deleted file mode 100644 index ec398851..00000000 --- a/lavague-sdk/lavague/sdk/base_driver.py +++ /dev/null @@ -1,748 +0,0 @@ -from PIL import Image -import os -from pathlib import Path -import re -from typing import Any, Callable, Optional, Mapping, Dict, Set, List, Tuple, Union -from abc import ABC, abstractmethod -from enum import Enum -from datetime import datetime -import hashlib - - -class InteractionType(Enum): - CLICK = "click" - HOVER = "hover" - SCROLL = "scroll" - TYPE = "type" - - -PossibleInteractionsByXpath = Dict[str, Set[InteractionType]] - -r_get_xpaths_from_html = r'xpath=["\'](.*?)["\']' - -class ScrollDirection(Enum): - """Enum for the different scroll directions. Value is (x, y, dimension_index)""" - - LEFT = (-1, 0, 0) - RIGHT = (1, 0, 0) - UP = (0, -1, 1) - DOWN = (0, 1, 1) - - def get_scroll_xy( - self, dimension: List[float], scroll_factor: float = 0.75 - ) -> Tuple[int, int]: - size = dimension[self.value[2]] - return ( - round(self.value[0] * size * scroll_factor), - round(self.value[1] * size * scroll_factor), - ) - - def get_page_script(self, scroll_factor: float = 0.75) -> str: - return f"window.scrollBy({self.value[0] * scroll_factor} * window.innerWidth, {self.value[1] * scroll_factor} * window.innerHeight);" - - def get_script_element_is_scrollable(self) -> str: - match self: - case ScrollDirection.UP: - return "return arguments[0].scrollTop > 0" - case ScrollDirection.DOWN: - return "return arguments[0].scrollTop + arguments[0].clientHeight + 1 < arguments[0].scrollHeight" - case ScrollDirection.LEFT: - return "return arguments[0].scrollLeft > 0" - case ScrollDirection.RIGHT: - return "return arguments[0].scrollLeft + arguments[0].clientWidth + 1 < arguments[0].scrollWidth" - - def get_script_page_is_scrollable(self) -> str: - match self: - case ScrollDirection.UP: - return "return window.scrollY > 0" - case ScrollDirection.DOWN: - return "return window.innerHeight + window.scrollY + 1 < document.body.scrollHeight" - case ScrollDirection.LEFT: - return "return window.scrollX > 0" - case ScrollDirection.RIGHT: - return "return window.innerWidth + window.scrollX + 1 < document.body.scrollWidth" - - @classmethod - def from_string(cls, name: str) -> "ScrollDirection": - return cls[name.upper().strip()] - - -class BaseDriver(ABC): - def __init__(self, url: Optional[str], init_function: Optional[Callable[[], Any]]): - """Init the driver with the init funtion, and then go to the desired url""" - self.init_function = ( - init_function if init_function is not None else self.default_init_code - ) - self.driver = self.init_function() - - # Flag to check if the page has been previously scanned to avoid erasing screenshots from previous scan - self.previously_scanned = False - - if url is not None: - self.get(url) - - async def connect(self) -> None: - """Connect to the driver""" - pass - - @abstractmethod - def default_init_code(self) -> Any: - """Init the driver, with the imports, since it will be pasted to the beginning of the output""" - pass - - @abstractmethod - def destroy(self) -> None: - """Cleanly destroy the underlying driver""" - pass - - @abstractmethod - def get_driver(self) -> Any: - """Return the expected variable name and the driver object""" - pass - - @abstractmethod - def resize_driver(driver, width, height): - """ - Resize the driver to a targeted height and width. - """ - - @abstractmethod - def get_url(self) -> Optional[str]: - """Get the url of the current page""" - pass - - @abstractmethod - def get(self, url: str) -> None: - """Navigate to the url""" - pass - - @abstractmethod - def code_for_get(self, url: str) -> str: - """Return the code to navigate to the url""" - pass - - @abstractmethod - def back(self) -> None: - """Navigate back""" - pass - - @abstractmethod - def maximize_window(self) -> None: - pass - - @abstractmethod - def code_for_back(self) -> None: - """Return driver specific code for going back""" - pass - - @abstractmethod - def get_html(self, clean: bool = True) -> str: - """ - Returns the HTML of the current page. - If clean is True, We remove unnecessary tags and attributes from the HTML. - Clean HTMLs are easier to process for the LLM. - """ - pass - - def get_tabs(self) -> str: - """Return description of the tabs opened with the current tab being focused. - - Example of output: - Tabs opened: - 0 - Overview - OpenAI API - 1 - [CURRENT] Nos destinations Train - SNCF Connect - """ - return "Tabs opened:\n 0 - [CURRENT] tab" - - def switch_tab(self, tab_id: int) -> None: - """Switch to the tab with the given id""" - pass - - def switch_frame(self, xpath) -> None: - """ - switch to the frame pointed at by the xpath - """ - raise NotImplementedError() - - def switch_default_frame(self) -> None: - """ - Switch back to the default frame - """ - raise NotImplementedError() - - def switch_parent_frame(self) -> None: - """ - Switch back to the parent frame - """ - raise NotImplementedError() - - @abstractmethod - def resolve_xpath(self, xpath) -> "DOMNode": - """ - Return the element for the corresponding xpath, the underlying driver may switch iframe if necessary - """ - pass - - def save_screenshot(self, current_screenshot_folder: Path) -> str: - """Save the screenshot data to a file and return the path. If the screenshot already exists, return the path. If not save it to the folder.""" - - new_screenshot = self.get_screenshot_as_png() - hasher = hashlib.md5() - hasher.update(new_screenshot) - new_hash = hasher.hexdigest() - new_screenshot_name = f"{new_hash}.png" - new_screenshot_full_path = current_screenshot_folder / new_screenshot_name - - # If the screenshot does not exist, save it - if not new_screenshot_full_path.exists(): - with open(new_screenshot_full_path, "wb") as f: - f.write(new_screenshot) - return str(new_screenshot_full_path) - - def is_bottom_of_page(self) -> bool: - return self.execute_script( - "return (window.innerHeight + window.scrollY + 1) >= document.body.scrollHeight;" - ) - - def get_screenshots_whole_page(self, max_screenshots=30) -> list[str]: - """Take screenshots of the whole page""" - screenshot_paths = [] - - current_screenshot_folder = self.get_current_screenshot_folder() - - for i in range(max_screenshots): - # Saves a screenshot - screenshot_path = self.save_screenshot(current_screenshot_folder) - screenshot_paths.append(screenshot_path) - self.scroll_down() - self.wait_for_idle() - - if self.is_bottom_of_page(): - break - - self.previously_scanned = True - return screenshot_paths - - @abstractmethod - def get_possible_interactions( - self, - in_viewport=True, - foreground_only=True, - types: List[InteractionType] = [ - InteractionType.CLICK, - InteractionType.TYPE, - InteractionType.HOVER, - ], - ) -> PossibleInteractionsByXpath: - """Get elements that can be interacted with as a dictionary mapped by xpath""" - pass - - def get_in_viewport(self) -> List[str]: - """Get xpath of elements in the viewport""" - return [] - - def check_visibility(self, xpath: str) -> bool: - return True - - @abstractmethod - def get_viewport_size(self) -> dict: - """Return viewport size as {"width": int, "height": int}""" - pass - - @abstractmethod - def get_highlighted_element(self, generated_code: str): - """Return the page elements that generated code interact with""" - pass - - @abstractmethod - def exec_code( - self, - code: str, - globals: dict[str, Any] = None, - locals: Mapping[str, object] = None, - ): - """Exec generated code""" - pass - - @abstractmethod - def execute_script(self, js_code: str, *args) -> Any: - """Exec js script in DOM""" - pass - - @abstractmethod - def scroll( - self, - xpath_anchor: Optional[str], - direction: ScrollDirection, - scroll_factor=0.75, - ): - pass - - # TODO: Remove these methods as they are not used - @abstractmethod - def scroll_up(self): - pass - - @abstractmethod - def scroll_down(self): - pass - - @abstractmethod - def code_for_execute_script(self, js_code: str): - """return driver specific code to execute js script in DOM""" - pass - - @abstractmethod - def get_capability(self) -> str: - """Prompt to explain the llm which style of code he should output and which variables and imports he should expect""" - pass - - def get_obs(self) -> dict: - """Get the current observation of the driver""" - current_screenshot_folder = self.get_current_screenshot_folder() - - if not self.previously_scanned: - # If the last operation was not to scan the whole page, we clear the screenshot folder - try: - if os.path.isdir(current_screenshot_folder): - for filename in os.listdir(current_screenshot_folder): - file_path = os.path.join(current_screenshot_folder, filename) - try: - # Check if it's a file and then delete it - if os.path.isfile(file_path) or os.path.islink(file_path): - os.remove(file_path) - except Exception as e: - print(f"Failed to delete {file_path}. Reason: {e}") - - except Exception as e: - raise Exception(f"Error while clearing screenshot folder: {e}") - else: - # If the last operation was to scan the whole page, we reset the flag - self.previously_scanned = False - - # We add labels to the scrollable elements - i_scroll = self.get_possible_interactions(types=[InteractionType.SCROLL]) - scrollables_xpaths = list(i_scroll.keys()) - - self.remove_highlight() - self.highlight_nodes(scrollables_xpaths, label=True) - - # We take a screenshot and computes its hash to see if it already exists - self.save_screenshot(current_screenshot_folder) - self.remove_highlight() - - url = self.get_url() - html = self.get_html() - obs = { - "html": html, - "screenshots_path": str(current_screenshot_folder), - "url": url, - "date": datetime.now().isoformat(), - "tab_info": self.get_tabs(), - } - - return obs - - def wait(self, duration): - import time - - time.sleep(duration) - - def wait_for_idle(self): - pass - - def get_current_screenshot_folder(self) -> Path: - url = self.get_url() - - if url is None: - url = "blank" - - screenshots_path = Path("./screenshots") - screenshots_path.mkdir(exist_ok=True) - - current_url = url.replace("://", "_").replace("/", "_") - hasher = hashlib.md5() - hasher.update(current_url.encode("utf-8")) - - current_screenshot_folder = screenshots_path / hasher.hexdigest() - current_screenshot_folder.mkdir(exist_ok=True) - return current_screenshot_folder - - @abstractmethod - def get_screenshot_as_png(self) -> bytes: - pass - - @abstractmethod - def get_shadow_roots(self) -> Dict[str, str]: - pass - - def get_nodes(self, xpaths: List[str]) -> List["DOMNode"]: - raise NotImplementedError("get_nodes not implemented") - - def get_nodes_from_html(self, html: str) -> List["DOMNode"]: - return self.get_nodes(re.findall(r_get_xpaths_from_html, html)) - - def highlight_node_from_xpath( - self, xpath: str, color: str = "red", label=False - ) -> Callable: - return self.highlight_nodes([xpath], color, label) - - def highlight_nodes( - self, xpaths: List[str], color: str = "red", label=False - ) -> Callable: - nodes = self.get_nodes(xpaths) - for n in nodes: - n.highlight(color) - return self._add_highlighted_destructors(lambda: [n.clear() for n in nodes]) - - def highlight_nodes_from_html( - self, html: str, color: str = "blue", label=False - ) -> Callable: - return self.highlight_nodes( - re.findall(r_get_xpaths_from_html, html), color, label - ) - - def remove_highlight(self): - if hasattr(self, "_highlight_destructors"): - for destructor in self._highlight_destructors: - destructor() - delattr(self, "_highlight_destructors") - - def _add_highlighted_destructors( - self, destructors: Union[List[Callable], Callable] - ) -> Callable: - if not hasattr(self, "_highlight_destructors"): - self._highlight_destructors = [] - if isinstance(destructors, Callable): - self._highlight_destructors.append(destructors) - return destructors - - self._highlight_destructors.extend(destructors) - return lambda: [d() for d in destructors] - - def highlight_interactive_nodes( - self, - *with_interactions: tuple[InteractionType], - color: str = "red", - in_viewport=True, - foreground_only=True, - label=False, - ): - if with_interactions is None or len(with_interactions) == 0: - return self.highlight_nodes( - list( - self.get_possible_interactions( - in_viewport=in_viewport, foreground_only=foreground_only - ).keys() - ), - color, - label, - ) - - return self.highlight_nodes( - [ - xpath - for xpath, interactions in self.get_possible_interactions( - in_viewport=in_viewport, foreground_only=foreground_only - ).items() - if set(interactions) & set(with_interactions) - ], - color, - label, - ) - - -class DOMNode(ABC): - @property - @abstractmethod - def element(self) -> Any: - pass - - @property - @abstractmethod - def value(self) -> Any: - pass - - @abstractmethod - def highlight(self, color: str = "red", bounding_box=True): - pass - - @abstractmethod - def clear(self): - return self - - @abstractmethod - def take_screenshot(self) -> Image.Image: - pass - - @abstractmethod - def get_html(self) -> str: - pass - - def __str__(self) -> str: - return self.get_html() - - def __enter__(self): - return self - - def __exit__(self, exc_type, exc_val, exc_tb): - pass - - - - -def js_wrap_function_call(fn: str): - return "(function(){" + fn + "})()" - - -JS_SETUP_GET_EVENTS = """ -(function() { - if (window && !window.getEventListeners) { - const targetProto = EventTarget.prototype; - targetProto._addEventListener = Element.prototype.addEventListener; - targetProto.addEventListener = function(a,b,c) { - this._addEventListener(a,b,c); - if(!this.eventListenerList) this.eventListenerList = {}; - if(!this.eventListenerList[a]) this.eventListenerList[a] = []; - this.eventListenerList[a].push(b); - }; - targetProto._removeEventListener = Element.prototype.removeEventListener; - targetProto.removeEventListener = function(a, b, c) { - this._removeEventListener(a, b, c); - if(this.eventListenerList && this.eventListenerList[a]) { - const index = this.eventListenerList[a].indexOf(b); - if (index > -1) { - this.eventListenerList[a].splice(index, 1); - if (!this.eventListenerList[a].length) { - delete this.eventListenerList[a]; - } - } - } - }; - window.getEventListeners = function(e) { - return (e && e.eventListenerList) || []; - } - } -})();""" - -JS_GET_INTERACTIVES = """ -const windowHeight = (window.innerHeight || document.documentElement.clientHeight); -const windowWidth = (window.innerWidth || document.documentElement.clientWidth); - -return (function(inViewport, foregroundOnly, nonInteractives, filterTypes) { - function getInteractions(e) { - const tag = e.tagName.toLowerCase(); - if (!e.checkVisibility() || (e.hasAttribute('disabled') && !nonInteractives) || e.hasAttribute('readonly') - || (tag === 'input' && e.getAttribute('type') === 'hidden') || tag === 'body') { - return []; - } - const rect = e.getBoundingClientRect(); - if (rect.width + rect.height < 5) { - return []; - } - const style = getComputedStyle(e) || {}; - if (style.display === 'none' || style.visibility === 'hidden') { - return []; - } - const events = window && typeof window.getEventListeners === 'function' ? window.getEventListeners(e) : []; - const role = e.getAttribute('role'); - const clickableInputs = ['submit', 'checkbox', 'radio', 'color', 'file', 'image', 'reset']; - function hasEvent(n) { - return events[n]?.length || e.hasAttribute('on' + n); - } - let evts = []; - if (hasEvent('keydown') || hasEvent('keyup') || hasEvent('keypress') || hasEvent('keydown') || hasEvent('input') || e.isContentEditable - || ( - (tag === 'input' || tag === 'textarea' || role === 'searchbox' || role === 'input') - ) && !clickableInputs.includes(e.getAttribute('type')) - ) { - evts.push('TYPE'); - } - if (['a', 'button', 'select'].includes(tag) || ['button', 'checkbox', 'select'].includes(role) - || hasEvent('click') || hasEvent('mousedown') || hasEvent('mouseup') || hasEvent('dblclick') - || style.cursor === 'pointer' - || e.hasAttribute('aria-haspopup') - || (tag === 'input' && clickableInputs.includes(e.getAttribute('type'))) - || (tag === 'label' && document.getElementById(e.getAttribute('for'))) - ) { - evts.push('CLICK'); - } - if ( - (hasEvent('scroll') || hasEvent('wheel') || style.overflow === 'auto' || style.overflow === 'scroll' || style.overflowY === 'auto' || style.overflowY === 'scroll') - && (e.scrollHeight > e.clientHeight || e.scrollWidth > e.clientWidth)) { - evts.push('SCROLL'); - } - if (filterTypes && evts.length) { - evts = evts.filter(t => filterTypes.includes(t)); - } - if (nonInteractives && evts.length === 0) { - evts.push('NONE'); - } - - if (inViewport) { - let rect = e.getBoundingClientRect(); - let iframe = e.ownerDocument.defaultView.frameElement; - while (iframe) { - const iframeRect = iframe.getBoundingClientRect(); - rect = { - top: rect.top + iframeRect.top, - left: rect.left + iframeRect.left, - bottom: rect.bottom + iframeRect.bottom, - right: rect.right + iframeRect.right, - width: rect.width, - height: rect.height - } - iframe = iframe.ownerDocument.defaultView.frameElement; - } - const elemCenter = { - x: Math.round(rect.left + rect.width / 2), - y: Math.round(rect.top + rect.height / 2) - }; - if (elemCenter.x < 0) return []; - if (elemCenter.x > windowWidth) return []; - if (elemCenter.y < 0) return []; - if (elemCenter.y > windowHeight) return []; - if (!foregroundOnly) return evts; // whenever to check for elements above - let pointContainer = document.elementFromPoint(elemCenter.x, elemCenter.y); - iframe = e.ownerDocument.defaultView.frameElement; - while (iframe) { - const iframeRect = iframe.getBoundingClientRect(); - pointContainer = iframe.contentDocument.elementFromPoint( - elemCenter.x - iframeRect.left, - elemCenter.y - iframeRect.top - ); - iframe = iframe.ownerDocument.defaultView.frameElement; - } - do { - if (pointContainer === e) return evts; - if (pointContainer == null) return evts; - } while (pointContainer = pointContainer.parentNode); - return []; - } - return evts; - } - - const results = {}; - function traverse(node, xpath) { - if (node.nodeType === Node.ELEMENT_NODE) { - const interactions = getInteractions(node); - if (interactions.length > 0) { - results[xpath] = interactions; - } - } - const countByTag = {}; - for (let child = node.firstChild; child; child = child.nextSibling) { - let tag = child.nodeName.toLowerCase(); - if (tag.includes(":")) continue; //namespace - let isLocal = ['svg'].includes(tag); - if (isLocal) { - tag = `*[local-name() = '${tag}']`; - } - countByTag[tag] = (countByTag[tag] || 0) + 1; - let childXpath = xpath + '/' + tag; - if (countByTag[tag] > 1) { - childXpath += '[' + countByTag[tag] + ']'; - } - if (tag === 'iframe') { - try { - traverse(child.contentWindow.document.body, childXpath + '/html/body'); - } catch (e) { - console.warn("iframe access blocked", child, e); - } - } else if (!isLocal) { - traverse(child, childXpath); - if (child.shadowRoot) { - traverse(child.shadowRoot, childXpath + '/'); - } - } - } - } - traverse(document.body, '/html/body'); - return results; -})(arguments?.[0], arguments?.[1], arguments?.[2], arguments?.[3]); -""" - -JS_WAIT_DOM_IDLE = """ -return new Promise(resolve => { - const timeout = arguments[0] || 10000; - const stabilityThreshold = arguments[1] || 100; - - let mutationObserver; - let timeoutId = null; - - const waitForIdle = () => { - if (timeoutId) clearTimeout(timeoutId); - timeoutId = setTimeout(() => resolve(true), stabilityThreshold); - }; - mutationObserver = new MutationObserver(waitForIdle); - mutationObserver.observe(document.body, { - childList: true, - attributes: true, - subtree: true, - }); - waitForIdle(); - - setTimeout(() => { - resolve(false); - mutationObserver.disconnect(); - mutationObserver = null; - if (timeoutId) { - clearTimeout(timeoutId); - timeoutId = null; - } - }, timeout); -}); -""" - -JS_GET_SCROLLABLE_PARENT = """ -let element = arguments[0]; -while (element) { - const style = window.getComputedStyle(element); - - // Check if the element is scrollable - if (style.overflow === 'auto' || style.overflow === 'scroll' || - style.overflowX === 'auto' || style.overflowX === 'scroll' || - style.overflowY === 'auto' || style.overflowY === 'scroll') { - - // Check if the element has a scrollable area - if (element.scrollHeight > element.clientHeight || - element.scrollWidth > element.clientWidth) { - return element; - } - } - element = element.parentElement; -} -return null; -""" - -JS_GET_SHADOW_ROOTS = """ -const results = {}; -function traverse(node, xpath) { - if (node.shadowRoot) { - results[xpath] = node.shadowRoot.getHTML(); - } - const countByTag = {}; - for (let child = node.firstChild; child; child = child.nextSibling) { - let tag = child.nodeName.toLowerCase(); - countByTag[tag] = (countByTag[tag] || 0) + 1; - let childXpath = xpath + '/' + tag; - if (countByTag[tag] > 1) { - childXpath += '[' + countByTag[tag] + ']'; - } - if (child.shadowRoot) { - traverse(child.shadowRoot, childXpath + '/'); - } - if (tag === 'iframe') { - try { - traverse(child.contentWindow.document.body, childXpath + '/html/body'); - } catch (e) { - console.warn("iframe access blocked", child, e); - } - } else { - traverse(child, childXpath); - } - } -} -traverse(document.body, '/html/body'); -return results; -""" diff --git a/lavague-sdk/lavague/sdk/base_driver/__init__.py b/lavague-sdk/lavague/sdk/base_driver/__init__.py new file mode 100644 index 00000000..b036ad9e --- /dev/null +++ b/lavague-sdk/lavague/sdk/base_driver/__init__.py @@ -0,0 +1 @@ +from lavague.sdk.base_driver.base import BaseDriver, DOMNode, DriverObservation diff --git a/lavague-sdk/lavague/sdk/base_driver/base.py b/lavague-sdk/lavague/sdk/base_driver/base.py new file mode 100644 index 00000000..5698506e --- /dev/null +++ b/lavague-sdk/lavague/sdk/base_driver/base.py @@ -0,0 +1,277 @@ +import re +from abc import ABC, abstractmethod +from contextlib import contextmanager +from datetime import datetime +from typing import Callable, Dict, List, Optional, Union, TypeVar, Generic +from pydantic import BaseModel +from lavague.sdk.action.navigation import NavigationOutput + +from lavague.sdk.base_driver.interaction import ( + InteractionType, + PossibleInteractionsByXpath, + ScrollDirection, +) +from lavague.sdk.base_driver.node import DOMNode + + +class DriverObservation(BaseModel): + html: str + screenshot: bytes + url: str + date: str + tab_info: str + + +T = TypeVar("T", bound=DOMNode, covariant=True) + + +class BaseDriver(ABC, Generic[T]): + @abstractmethod + def init(self) -> None: + """Init the underlying driver""" + pass + + @abstractmethod + def execute(self, action: NavigationOutput) -> None: + """Execute an action""" + pass + + @abstractmethod + def destroy(self) -> None: + """Cleanly destroy the underlying driver""" + pass + + @abstractmethod + def resize_driver(self, width: int, height: int): + """Resize the viewport to a targeted height and width""" + + @abstractmethod + def get_url(self) -> str: + """Get the url of the current page, raise NoPageException if no page is loaded""" + pass + + @abstractmethod + def get(self, url: str) -> None: + """Navigate to the url""" + pass + + @abstractmethod + def back(self) -> None: + """Navigate back, raise CannotBackException if history root is reached""" + pass + + @abstractmethod + def get_html(self) -> str: + """ + Returns the HTML of the current page. + If clean is True, We remove unnecessary tags and attributes from the HTML. + Clean HTMLs are easier to process for the LLM. + """ + pass + + @abstractmethod + def get_tabs(self) -> str: + """Return description of the tabs opened with the current tab being focused. + + Example of output: + Tabs opened: + 0 - Overview - OpenAI API + 1 - [CURRENT] Nos destinations Train - SNCF Connect + """ + pass + + @abstractmethod + def switch_tab(self, tab_id: int) -> None: + """Switch to the tab with the given id""" + pass + + @abstractmethod + def resolve_xpath(self, xpath: str) -> T: + """ + Return the element for the corresponding xpath, the underlying driver may switch iframe if necessary + """ + pass + + @abstractmethod + def get_viewport_size(self) -> dict: + """Return viewport size as {"width": int, "height": int}""" + pass + + @abstractmethod + def get_possible_interactions( + self, + in_viewport=True, + foreground_only=True, + types: List[InteractionType] = [ + InteractionType.CLICK, + InteractionType.TYPE, + InteractionType.HOVER, + ], + ) -> PossibleInteractionsByXpath: + """Get elements that can be interacted with as a dictionary mapped by xpath""" + pass + + @abstractmethod + def scroll( + self, + xpath_anchor: Optional[str] = "/html/body", + direction: ScrollDirection = ScrollDirection.DOWN, + scroll_factor=0.75, + ): + pass + + @abstractmethod + def scroll_into_view(self, xpath: str): + pass + + @abstractmethod + def wait_for_idle(self): + pass + + @abstractmethod + def get_capability(self) -> str: + """Prompt to explain the llm which style of code he should output and which variables and imports he should expect""" + pass + + @abstractmethod + def get_screenshot_as_png(self) -> bytes: + pass + + @abstractmethod + def get_shadow_roots(self) -> Dict[str, str]: + """Return a dictionary of shadow roots HTML by xpath""" + pass + + @abstractmethod + def get_nodes(self, xpaths: List[str]) -> List[T]: + pass + + @abstractmethod + def highlight_nodes( + self, xpaths: List[str], color: str = "red", label=False + ) -> Callable: + pass + + @abstractmethod + def switch_frame(self, xpath: str) -> None: + """Switch to the frame with the given xpath, use with care as it changes the state of the driver""" + pass + + @abstractmethod + def switch_parent_frame(self) -> None: + """Switch to the parent frame, use with care as it changes the state of the driver""" + pass + + @contextmanager + def nodes_highlighter(self, nodes: List[str], color: str = "red", label=False): + """Highlight nodes for a context manager""" + remove_highlight = self.highlight_nodes(nodes, color, label) + yield + remove_highlight() + + def get_obs(self) -> DriverObservation: + """Get the current observation of the driver""" + + # We add labels to the scrollable elements + scrollables = self.get_scroll_containers() + with self.nodes_highlighter(scrollables, label=True): + screenshot = self.get_screenshot_as_png() + + url = self.get_url() + html = self.get_html() + tab_info = self.get_tabs() + + return DriverObservation( + html=html, + screenshot=screenshot, + url=url, + date=datetime.now().isoformat(), + tab_info=tab_info, + ) + + def get_in_viewport(self) -> List[str]: + """Get xpath of elements in the viewport""" + interactions = self.get_possible_interactions(in_viewport=True, types=[]) + return list(interactions.keys()) + + def get_scroll_containers(self) -> List[str]: + """Get xpath of elements in the viewport""" + interactions = self.get_possible_interactions(types=[InteractionType.SCROLL]) + return list(interactions.keys()) + + def get_nodes_from_html(self, html: str) -> List[T]: + return self.get_nodes(re.findall(r"xpath=[\"'](.*?)[\"']", html)) + + def highlight_node_from_xpath( + self, xpath: str, color: str = "red", label=False + ) -> Callable: + return self.highlight_nodes([xpath], color, label) + + def highlight_nodes_from_html( + self, html: str, color: str = "blue", label=False + ) -> Callable: + return self.highlight_nodes( + re.findall(r"xpath=[\"'](.*?)[\"']", html), color, label + ) + + def remove_highlight(self): + if hasattr(self, "_highlight_destructors"): + for destructor in self._highlight_destructors: + destructor() + delattr(self, "_highlight_destructors") + + def _add_highlighted_destructors( + self, destructors: Union[List[Callable], Callable] + ) -> Callable: + if not hasattr(self, "_highlight_destructors"): + self._highlight_destructors = [] + if isinstance(destructors, Callable): + self._highlight_destructors.append(destructors) + return destructors + + self._highlight_destructors.extend(destructors) + return lambda: [d() for d in destructors] + + def highlight_interactive_nodes( + self, + *with_interactions: tuple[InteractionType], + color: str = "red", + in_viewport=True, + foreground_only=True, + label=False, + ): + if with_interactions is None or len(with_interactions) == 0: + return self.highlight_nodes( + list( + self.get_possible_interactions( + in_viewport=in_viewport, foreground_only=foreground_only + ).keys() + ), + color, + label, + ) + + return self.highlight_nodes( + [ + xpath + for xpath, interactions in self.get_possible_interactions( + in_viewport=in_viewport, foreground_only=foreground_only + ).items() + if set(interactions) & set(with_interactions) + ], + color, + label, + ) + + def __enter__(self): + self.init() + self.driver_ready = True + return self + + def __exit__(self, *_): + self.destroy() + self.driver_ready = False + + def __del__(self): + if self.driver_ready: + self.__exit__() diff --git a/lavague-sdk/lavague/sdk/base_driver/interaction.py b/lavague-sdk/lavague/sdk/base_driver/interaction.py new file mode 100644 index 00000000..4024c735 --- /dev/null +++ b/lavague-sdk/lavague/sdk/base_driver/interaction.py @@ -0,0 +1,59 @@ +from typing import Dict, Set, List, Tuple +from enum import Enum + + +class InteractionType(Enum): + CLICK = "click" + HOVER = "hover" + SCROLL = "scroll" + TYPE = "type" + + +PossibleInteractionsByXpath = Dict[str, Set[InteractionType]] + + +class ScrollDirection(Enum): + """Enum for the different scroll directions. Value is (x, y, dimension_index)""" + + LEFT = (-1, 0, 0) + RIGHT = (1, 0, 0) + UP = (0, -1, 1) + DOWN = (0, 1, 1) + + def get_scroll_xy( + self, dimension: List[float], scroll_factor: float = 0.75 + ) -> Tuple[int, int]: + size = dimension[self.value[2]] + return ( + round(self.value[0] * size * scroll_factor), + round(self.value[1] * size * scroll_factor), + ) + + def get_page_script(self, scroll_factor: float = 0.75) -> str: + return f"window.scrollBy({self.value[0] * scroll_factor} * window.innerWidth, {self.value[1] * scroll_factor} * window.innerHeight);" + + def get_script_element_is_scrollable(self) -> str: + match self: + case ScrollDirection.UP: + return "return arguments[0].scrollTop > 0" + case ScrollDirection.DOWN: + return "return arguments[0].scrollTop + arguments[0].clientHeight + 1 < arguments[0].scrollHeight" + case ScrollDirection.LEFT: + return "return arguments[0].scrollLeft > 0" + case ScrollDirection.RIGHT: + return "return arguments[0].scrollLeft + arguments[0].clientWidth + 1 < arguments[0].scrollWidth" + + def get_script_page_is_scrollable(self) -> str: + match self: + case ScrollDirection.UP: + return "return window.scrollY > 0" + case ScrollDirection.DOWN: + return "return window.innerHeight + window.scrollY + 1 < document.body.scrollHeight" + case ScrollDirection.LEFT: + return "return window.scrollX > 0" + case ScrollDirection.RIGHT: + return "return window.innerWidth + window.scrollX + 1 < document.body.scrollWidth" + + @classmethod + def from_string(cls, name: str) -> "ScrollDirection": + return cls[name.upper().strip()] diff --git a/lavague-sdk/lavague/sdk/base_driver/javascript.py b/lavague-sdk/lavague/sdk/base_driver/javascript.py new file mode 100644 index 00000000..79b1a39d --- /dev/null +++ b/lavague-sdk/lavague/sdk/base_driver/javascript.py @@ -0,0 +1,319 @@ +def js_wrap_function_call(fn: str): + return "(function(){" + fn + "})()" + + +JS_SETUP_GET_EVENTS = """ +(function() { + if (window && !window.getEventListeners) { + const targetProto = EventTarget.prototype; + targetProto._addEventListener = Element.prototype.addEventListener; + targetProto.addEventListener = function(a,b,c) { + this._addEventListener(a,b,c); + if(!this.eventListenerList) this.eventListenerList = {}; + if(!this.eventListenerList[a]) this.eventListenerList[a] = []; + this.eventListenerList[a].push(b); + }; + targetProto._removeEventListener = Element.prototype.removeEventListener; + targetProto.removeEventListener = function(a, b, c) { + this._removeEventListener(a, b, c); + if(this.eventListenerList && this.eventListenerList[a]) { + const index = this.eventListenerList[a].indexOf(b); + if (index > -1) { + this.eventListenerList[a].splice(index, 1); + if (!this.eventListenerList[a].length) { + delete this.eventListenerList[a]; + } + } + } + }; + window.getEventListeners = function(e) { + return (e && e.eventListenerList) || []; + } + } +})();""" + +JS_GET_INTERACTIVES = """ +const windowHeight = (window.innerHeight || document.documentElement.clientHeight); +const windowWidth = (window.innerWidth || document.documentElement.clientWidth); + +return (function(inViewport, foregroundOnly, nonInteractives, filterTypes) { + function getInteractions(e) { + const tag = e.tagName.toLowerCase(); + if (!e.checkVisibility() || (e.hasAttribute('disabled') && !nonInteractives) || e.hasAttribute('readonly') + || (tag === 'input' && e.getAttribute('type') === 'hidden') || tag === 'body') { + return []; + } + const rect = e.getBoundingClientRect(); + if (rect.width + rect.height < 5) { + return []; + } + const style = getComputedStyle(e) || {}; + if (style.display === 'none' || style.visibility === 'hidden') { + return []; + } + const events = window && typeof window.getEventListeners === 'function' ? window.getEventListeners(e) : []; + const role = e.getAttribute('role'); + const clickableInputs = ['submit', 'checkbox', 'radio', 'color', 'file', 'image', 'reset']; + function hasEvent(n) { + return events[n]?.length || e.hasAttribute('on' + n); + } + let evts = []; + if (hasEvent('keydown') || hasEvent('keyup') || hasEvent('keypress') || hasEvent('keydown') || hasEvent('input') || e.isContentEditable + || ( + (tag === 'input' || tag === 'textarea' || role === 'searchbox' || role === 'input') + ) && !clickableInputs.includes(e.getAttribute('type')) + ) { + evts.push('TYPE'); + } + if (['a', 'button', 'select'].includes(tag) || ['button', 'checkbox', 'select'].includes(role) + || hasEvent('click') || hasEvent('mousedown') || hasEvent('mouseup') || hasEvent('dblclick') + || style.cursor === 'pointer' + || e.hasAttribute('aria-haspopup') + || (tag === 'input' && clickableInputs.includes(e.getAttribute('type'))) + || (tag === 'label' && document.getElementById(e.getAttribute('for'))) + ) { + evts.push('CLICK'); + } + if ( + (hasEvent('scroll') || hasEvent('wheel') || style.overflow === 'auto' || style.overflow === 'scroll' || style.overflowY === 'auto' || style.overflowY === 'scroll') + && (e.scrollHeight > e.clientHeight || e.scrollWidth > e.clientWidth)) { + evts.push('SCROLL'); + } + if (filterTypes && evts.length) { + evts = evts.filter(t => filterTypes.includes(t)); + } + if (nonInteractives && evts.length === 0) { + evts.push('NONE'); + } + + if (inViewport) { + let rect = e.getBoundingClientRect(); + let iframe = e.ownerDocument.defaultView.frameElement; + while (iframe) { + const iframeRect = iframe.getBoundingClientRect(); + rect = { + top: rect.top + iframeRect.top, + left: rect.left + iframeRect.left, + bottom: rect.bottom + iframeRect.bottom, + right: rect.right + iframeRect.right, + width: rect.width, + height: rect.height + } + iframe = iframe.ownerDocument.defaultView.frameElement; + } + const elemCenter = { + x: Math.round(rect.left + rect.width / 2), + y: Math.round(rect.top + rect.height / 2) + }; + if (elemCenter.x < 0) return []; + if (elemCenter.x > windowWidth) return []; + if (elemCenter.y < 0) return []; + if (elemCenter.y > windowHeight) return []; + if (!foregroundOnly) return evts; // whenever to check for elements above + let pointContainer = document.elementFromPoint(elemCenter.x, elemCenter.y); + iframe = e.ownerDocument.defaultView.frameElement; + while (iframe) { + const iframeRect = iframe.getBoundingClientRect(); + pointContainer = iframe.contentDocument.elementFromPoint( + elemCenter.x - iframeRect.left, + elemCenter.y - iframeRect.top + ); + iframe = iframe.ownerDocument.defaultView.frameElement; + } + do { + if (pointContainer === e) return evts; + if (pointContainer == null) return evts; + } while (pointContainer = pointContainer.parentNode); + return []; + } + return evts; + } + + const results = {}; + function traverse(node, xpath) { + if (node.nodeType === Node.ELEMENT_NODE) { + const interactions = getInteractions(node); + if (interactions.length > 0) { + results[xpath] = interactions; + } + } + const countByTag = {}; + for (let child = node.firstChild; child; child = child.nextSibling) { + let tag = child.nodeName.toLowerCase(); + if (tag.includes(":")) continue; //namespace + let isLocal = ['svg'].includes(tag); + if (isLocal) { + tag = `*[local-name() = '${tag}']`; + } + countByTag[tag] = (countByTag[tag] || 0) + 1; + let childXpath = xpath + '/' + tag; + if (countByTag[tag] > 1) { + childXpath += '[' + countByTag[tag] + ']'; + } + if (tag === 'iframe') { + try { + traverse(child.contentWindow.document.body, childXpath + '/html/body'); + } catch (e) { + console.warn("iframe access blocked", child, e); + } + } else if (!isLocal) { + traverse(child, childXpath); + if (child.shadowRoot) { + traverse(child.shadowRoot, childXpath + '/'); + } + } + } + } + traverse(document.body, '/html/body'); + return results; +})(arguments?.[0], arguments?.[1], arguments?.[2], arguments?.[3]); +""" + +JS_WAIT_DOM_IDLE = """ +return new Promise(resolve => { + const timeout = arguments[0] || 10000; + const stabilityThreshold = arguments[1] || 100; + + let mutationObserver; + let timeoutId = null; + + const waitForIdle = () => { + if (timeoutId) clearTimeout(timeoutId); + timeoutId = setTimeout(() => resolve(true), stabilityThreshold); + }; + mutationObserver = new MutationObserver(waitForIdle); + mutationObserver.observe(document.body, { + childList: true, + attributes: true, + subtree: true, + }); + waitForIdle(); + + setTimeout(() => { + resolve(false); + mutationObserver.disconnect(); + mutationObserver = null; + if (timeoutId) { + clearTimeout(timeoutId); + timeoutId = null; + } + }, timeout); +}); +""" + +JS_GET_SCROLLABLE_PARENT = """ +let element = arguments[0]; +while (element) { + const style = window.getComputedStyle(element); + + // Check if the element is scrollable + if (style.overflow === 'auto' || style.overflow === 'scroll' || + style.overflowX === 'auto' || style.overflowX === 'scroll' || + style.overflowY === 'auto' || style.overflowY === 'scroll') { + + // Check if the element has a scrollable area + if (element.scrollHeight > element.clientHeight || + element.scrollWidth > element.clientWidth) { + return element; + } + } + element = element.parentElement; +} +return null; +""" + +JS_GET_SHADOW_ROOTS = """ +const results = {}; +function traverse(node, xpath) { + if (node.shadowRoot) { + results[xpath] = node.shadowRoot.getHTML(); + } + const countByTag = {}; + for (let child = node.firstChild; child; child = child.nextSibling) { + let tag = child.nodeName.toLowerCase(); + countByTag[tag] = (countByTag[tag] || 0) + 1; + let childXpath = xpath + '/' + tag; + if (countByTag[tag] > 1) { + childXpath += '[' + countByTag[tag] + ']'; + } + if (child.shadowRoot) { + traverse(child.shadowRoot, childXpath + '/'); + } + if (tag === 'iframe') { + try { + traverse(child.contentWindow.document.body, childXpath + '/html/body'); + } catch (e) { + console.warn("iframe access blocked", child, e); + } + } else { + traverse(child, childXpath); + } + } +} +traverse(document.body, '/html/body'); +return results; +""" + +ATTACH_MOVE_LISTENER = """ +if (!window._lavague_move_listener) { + window._lavague_move_listener = function() { + const bbs = document.querySelectorAll('.lavague-highlight'); + bbs.forEach(bb => { + const rect = bb._tracking.getBoundingClientRect(); + bb.style.top = rect.top + 'px'; + bb.style.left = rect.left + 'px'; + bb.style.width = rect.width + 'px'; + bb.style.height = rect.height + 'px'; + }); + }; + window.addEventListener('scroll', window._lavague_move_listener); + window.addEventListener('resize', window._lavague_move_listener); +} +""" + +REMOVE_HIGHLIGHT = """ +if (window._lavague_move_listener) { + window.removeEventListener('scroll', window._lavague_move_listener); + window.removeEventListener('resize', window._lavague_move_listener); + delete window._lavague_move_listener; +} +arguments[0].filter(a => a).forEach(a => a.style.cssText = a.dataset.originalStyle || ''); +document.querySelectorAll('.lavague-highlight').forEach(a => a.remove()); +""" + + +def get_highlighter_style(color: str = "red", label: bool = False): + set_style = f""" + const r = a.getBoundingClientRect(); + const bb = document.createElement('div'); + const s = window.getComputedStyle(a); + bb.className = 'lavague-highlight'; + bb.style.position = 'fixed'; + bb.style.top = r.top + 'px'; + bb.style.left = r.left + 'px'; + bb.style.width = r.width + 'px'; + bb.style.height = r.height + 'px'; + bb.style.border = '3px solid {color}'; + bb.style.borderRadius = s.borderRadius; + bb.style['z-index'] = '2147483647'; + bb.style['pointer-events'] = 'none'; + bb._tracking = a; + document.body.appendChild(bb); + """ + + if label: + set_style += """ + const label = document.createElement('div'); + label.style.position = 'absolute'; + label.style.backgroundColor = 'red'; + label.style.color = 'white'; + label.style.padding = '0px 6px 2px 4px'; + label.style.top = '-12px'; + label.style.left = '-12px'; + label.style['font-size'] = '13pt'; + label.style['font-weight'] = 'bold'; + label.style['border-bottom-right-radius'] = '13px'; + label.textContent = i + 1; + bb.appendChild(label); + """ + return set_style diff --git a/lavague-sdk/lavague/sdk/base_driver/node.py b/lavague-sdk/lavague/sdk/base_driver/node.py new file mode 100644 index 00000000..7bd2fa1f --- /dev/null +++ b/lavague-sdk/lavague/sdk/base_driver/node.py @@ -0,0 +1,57 @@ +from abc import ABC, abstractmethod +from typing import Generic, Optional, TypeVar + +from PIL import Image + + +T = TypeVar("T") + + +class DOMNode(ABC, Generic[T]): + @property + @abstractmethod + def element(self) -> T: + pass + + @property + @abstractmethod + def text(self) -> str: + pass + + @property + @abstractmethod + def value(self) -> Optional[str]: + pass + + @property + @abstractmethod + def outer_html(self) -> str: + pass + + @property + @abstractmethod + def inner_html(self) -> str: + pass + + @abstractmethod + def take_screenshot(self) -> Image.Image: + pass + + @abstractmethod + def enter_context(self): + pass + + @abstractmethod + def exit_context(self): + pass + + def __str__(self) -> str: + with self: + return self.outer_html + + def __enter__(self): + self.enter_context() + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + self.exit_context() diff --git a/lavague-sdk/lavague/sdk/client.py b/lavague-sdk/lavague/sdk/client.py index 64688abe..1d0ec531 100644 --- a/lavague-sdk/lavague/sdk/client.py +++ b/lavague-sdk/lavague/sdk/client.py @@ -1,12 +1,29 @@ -from lavague.sdk.trajectory.model import StepCompletion -from lavague.sdk.utilities.config import get_config, is_flag_true, LAVAGUE_API_BASE_URL -from lavague.sdk.action import ActionParser, DEFAULT_PARSER +from io import BytesIO +from typing import Any, Optional, Tuple + +import requests +from lavague.sdk.action import DEFAULT_PARSER, ActionParser from lavague.sdk.trajectory import Trajectory from lavague.sdk.trajectory.controller import TrajectoryController -from typing import Any, Optional +from lavague.sdk.trajectory.model import StepCompletion +from lavague.sdk.utilities.config import LAVAGUE_API_BASE_URL, get_config, is_flag_true from PIL import Image, ImageFile -from io import BytesIO -import requests +from pydantic import BaseModel + + +class RunRequest(BaseModel): + url: str + objective: str + step_by_step: Optional[bool] = False + cloud_driver: Optional[bool] = True + await_completion: Optional[bool] = False + is_public: Optional[bool] = False + viewport_size: Optional[Tuple[int, int]] = None + + +class RunUpdate(BaseModel): + objective: Optional[str] = None + is_public: Optional[bool] = False class LaVague(TrajectoryController): @@ -44,11 +61,19 @@ def request_api( raise ApiException(response.text) return response.content - def run(self, url: str, objective: str, step_by_step=False) -> Trajectory: + def run(self, request: RunRequest) -> Trajectory: content = self.request_api( "/runs", "POST", - {"url": url, "objective": objective, "step_by_step": step_by_step}, + request.model_dump(), + ) + return Trajectory.from_data(content, self.parser, self) + + def update(self, run_id: str, request: RunUpdate) -> Trajectory: + content = self.request_api( + f"/runs/{run_id}", + "PATCH", + request.model_dump(), ) return Trajectory.from_data(content, self.parser, self) diff --git a/lavague-sdk/lavague/sdk/exceptions.py b/lavague-sdk/lavague/sdk/exceptions.py index db902474..d82869af 100644 --- a/lavague-sdk/lavague/sdk/exceptions.py +++ b/lavague-sdk/lavague/sdk/exceptions.py @@ -14,6 +14,11 @@ def __init__(self, message="History root reached, cannot go back"): super().__init__(message) +class NoPageException(NavigationException): + def __init__(self, message="No page loaded"): + super().__init__(message) + + class RetrievalException(NavigationException): pass diff --git a/lavague-sdk/lavague/sdk/trajectory/base.py b/lavague-sdk/lavague/sdk/trajectory/base.py index 30fca6d7..752391bb 100644 --- a/lavague-sdk/lavague/sdk/trajectory/base.py +++ b/lavague-sdk/lavague/sdk/trajectory/base.py @@ -33,12 +33,15 @@ def run_to_completion(self): def stop_run(self): self._controller.stop(self.run_id) - self.status = RunStatus.CANCELLED + self.status = RunStatus.INTERRUPTED + self.error_msg = "Run interrupted by user" def iter_actions(self) -> Iterator[Action]: yield from self.actions while self.is_running: - yield self.next_action() + action = self.next_action() + if action is not None: + yield action @classmethod def from_data( diff --git a/lavague-sdk/lavague/sdk/trajectory/model.py b/lavague-sdk/lavague/sdk/trajectory/model.py index 554cae3f..4a79f143 100644 --- a/lavague-sdk/lavague/sdk/trajectory/model.py +++ b/lavague-sdk/lavague/sdk/trajectory/model.py @@ -21,6 +21,7 @@ class TrajectoryData(BaseModel): viewport_size: Tuple[int, int] status: RunStatus actions: List[SerializeAsAny[Action]] + error_msg: Optional[str] = None def write_to_file(self, file_path: str): json_model = self.model_dump_json(indent=2) diff --git a/lavague-sdk/lavague/sdk/utilities/version_checker.py b/lavague-sdk/lavague/sdk/utilities/version_checker.py index ded0ee2b..8dfd7de1 100644 --- a/lavague-sdk/lavague/sdk/utilities/version_checker.py +++ b/lavague-sdk/lavague/sdk/utilities/version_checker.py @@ -41,6 +41,8 @@ def check_latest_version(): url = "https://pypi.org/pypi/lavague-sdk/json" response = requests.get(url) data = response.json() + if data.get("message") == "Not Found": + return latest_version = data["info"]["version"] if compare_versions(package_version, latest_version) < 0: warnings.warn(