diff --git a/.github/new-code-check.py b/.github/new-code-check.py deleted file mode 100644 index 8e10193c..00000000 --- a/.github/new-code-check.py +++ /dev/null @@ -1,78 +0,0 @@ -from lavague.drivers.selenium import SeleniumDriver -from lavague.core import ActionEngine, WorldModel -from lavague.core.agents import WebAgent - -selenium_driver = SeleniumDriver() -action_engine = ActionEngine(selenium_driver) -world_model = WorldModel() - -agent = WebAgent(world_model, action_engine) - -url = "https://huggingface.co" -objective = "Provide the code to use Falcon 11B" - -agent.get(url) -output = agent.run(objective, display=False) - -result = output[-1] - -expected_output = """from transformers import AutoTokenizer, AutoModelForCausalLM -import transformers -import torch - -model = "tiiuae/falcon-11B" - -tokenizer = AutoTokenizer.from_pretrained(model) -pipeline = transformers.pipeline( - "text-generation", - model=model, - tokenizer=tokenizer, - torch_dtype=torch.bfloat16, - device_map="auto", -) -sequences = pipeline( - "Can you explain the concepts of Quantum Computing?", - max_length=200, - do_sample=True, - top_k=10, - num_return_sequences=1, - eos_token_id=tokenizer.eos_token_id, -) -for seq in sequences: - print(f"Result: {seq['generated_text']}")""" - -expected_output_2 = """from transformers import AutoTokenizer, AutoModelForCausalLM -import transformers -import torch - -model = "tiiuae/falcon-11B" - -tokenizer = AutoTokenizer.from_pretrained(model) -pipeline = transformers.pipeline( - "text-generation", - model=model, - tokenizer=tokenizer, - torch_dtype=torch.bfloat16, -) -sequences = pipeline( - "Can you explain the concepts of Quantum Computing?", - max_length=200, - do_sample=True, - top_k=10, - num_return_sequences=1, - eos_token_id=tokenizer.eos_token_id, -) -for seq in sequences: - print(f"Result: {seq['generated_text']}")""" - -# Remove all whitespace characters from both strings before comparison -expected_output_stripped = "".join(expected_output.strip().split()) -expected_output_2_stripped = "".join(expected_output_2.strip().split()) -result_stripped = "".join(result.strip().split()) - -# Check if the stripped expected output is contained within the stripped result -assert ( - expected_output_stripped in result_stripped - or expected_output_2_stripped in result_stripped -), f"Output does not match expected:\nExpected: {expected_output}\nActual: {result}" -print("Output matches expected.") diff --git a/.github/workflows/docs-code.yaml b/.github/workflows/docs-code.yaml index 12528291..0be5e01f 100644 --- a/.github/workflows/docs-code.yaml +++ b/.github/workflows/docs-code.yaml @@ -9,7 +9,6 @@ on: - 'lavague-gradio/**' - 'pyproject.toml' - '.github/workflows/docs-code.yaml' - - '.github/new-code-check.py' workflow_dispatch: jobs: diff --git a/lavague-integrations/drivers/lavague-drivers-playwright/lavague/drivers/playwright/base.py b/lavague-integrations/drivers/lavague-drivers-playwright/lavague/drivers/playwright/base.py index 3a96a3ea..78bbf337 100644 --- a/lavague-integrations/drivers/lavague-drivers-playwright/lavague/drivers/playwright/base.py +++ b/lavague-integrations/drivers/lavague-drivers-playwright/lavague/drivers/playwright/base.py @@ -3,7 +3,6 @@ import os from PIL import Image from typing import Callable, Optional, Any, Mapping, Dict, List -from lavague.core.utilities.format_utils import extract_code_from_funct from playwright.sync_api import Page, Locator from lavague.sdk.base_driver import ( BaseDriver, @@ -12,7 +11,7 @@ PossibleInteractionsByXpath, InteractionType, ) -from lavague.core.exceptions import ( +from lavague.sdk.exceptions import ( NoElementException, AmbiguousException, ) @@ -82,39 +81,6 @@ def default_init_code(self) -> Page: self.resize_driver(self.width, self.height) return self.page - def code_for_init(self) -> str: - init_lines = extract_code_from_funct(self.init_function) - code_lines = [ - "from playwright.sync_api import sync_playwright", - "", - ] - ignore_next = 0 - keep_else = False - start = False - for line in init_lines: - if "__enter__()" in line: - start = True - elif not start: - continue - if "self.headless" in line: - line = line.replace("self.headless", str(self.headless)) - if "self.user_data_dir" in line: - line = line.replace("self.user_data_dir", f'"{self.user_data_dir}"') - if "if" in line: - if self.user_data_dir is not None: - ignore_next = 1 - keep_else = True - elif "else" in line: - if not keep_else: - ignore_next = 3 - elif ignore_next <= 0: - if "self" not in line: - code_lines.append(line.strip()) - else: - ignore_next -= 1 - code_lines.append(self.code_for_resize(self.width, self.height)) - return "\n".join(code_lines) + "\n" - def get_driver(self) -> Page: return self.page diff --git a/lavague-integrations/drivers/lavague-drivers-selenium/lavague/drivers/selenium/base.py b/lavague-integrations/drivers/lavague-drivers-selenium/lavague/drivers/selenium/base.py index d40e5a2d..bc14ddb1 100644 --- a/lavague-integrations/drivers/lavague-drivers-selenium/lavague/drivers/selenium/base.py +++ b/lavague-integrations/drivers/lavague-drivers-selenium/lavague/drivers/selenium/base.py @@ -24,7 +24,7 @@ InteractionType, DOMNode, ) -from lavague.core.exceptions import ( +from lavague.sdk.exceptions import ( CannotBackException, NoElementException, AmbiguousException, @@ -32,10 +32,6 @@ from PIL import Image from io import BytesIO from selenium.webdriver.chrome.options import Options -from lavague.core.utilities.format_utils import ( - extract_code_from_funct, - quote_numeric_yaml_values, -) from selenium.webdriver.common.action_chains import ActionChains import time import yaml @@ -127,28 +123,6 @@ def default_init_code(self) -> Any: self.resize_driver(self.width, self.height) return self.driver - def code_for_init(self) -> str: - init_lines = extract_code_from_funct(self.init_function) - code_lines = [] - keep_next = True - for line in init_lines: - if "--user-data-dir" in line: - line = line.replace( - f"{{self.user_data_dir}}", f'"{self.user_data_dir}"' - ) - if "if" in line: - if ("headless" in line and not self.headless) or ( - "user_data_dir" in line and self.user_data_dir is None - ): - keep_next = False - elif keep_next: - if "self" not in line: - code_lines.append(line.strip()) - else: - keep_next = True - code_lines.append(self.code_for_resize(self.width, self.height)) - return "\n".join(code_lines) + "\n" - def __enter__(self): return self diff --git a/lavague-integrations/drivers/lavague-drivers-selenium/lavague/drivers/selenium/listener.py b/lavague-integrations/drivers/lavague-drivers-selenium/lavague/drivers/selenium/listener.py index 0d64621d..8a9acf9a 100644 --- a/lavague-integrations/drivers/lavague-drivers-selenium/lavague/drivers/selenium/listener.py +++ b/lavague-integrations/drivers/lavague-drivers-selenium/lavague/drivers/selenium/listener.py @@ -1,5 +1,108 @@ from selenium.webdriver.remote.webdriver import WebDriver -from lavague.core.listener import EventListener as BaseEventListener +from selenium.common.exceptions import TimeoutException +from typing import Callable, Any, List, Dict, Literal, Optional +import threading + + +class BaseEventListener: + """ + A utility class for listening to DOM events using a JS executor. + + This class provides methods to listen for specific user actions (like clicks) on web elements + identified by their XPath and execute a callback function upon detection. + """ + + def __init__(self, executor: Callable[[str, bool, Any], Any]): + self.executor = executor + self._destructors = [] + + def listen_next_action( + self, xpaths: Optional[List[str]] = None, no_timeout=False, prevent_action=False + ) -> Dict[Literal["eventType", "key", "button", "xpath", "element"], Any]: + """ + Listens for the next user action (such as a click) on elements that match the given xpaths, and prevent default behaviour. + This method blocks until an action is detected or a timeout occurs, unless no_timeout is set to True. + If xpaths is None all actions will be listened for. + Returns a dictionary containing information about the detected action. + """ + try: + event_data = self.executor(JS_LISTEN_ACTION, prevent_action, xpaths) + return event_data + except TimeoutException as e: + if no_timeout: + return self.listen_next_action( + xpaths=xpaths, no_timeout=no_timeout, prevent_action=prevent_action + ) + raise e + + def listen_next_action_async( + self, + callback: Callable, + xpaths: Optional[List[str]] = None, + no_timeout=False, + prevent_action=False, + ): + """ + Same as listen_next_action but async with a callback. + """ + thread = threading.Thread( + target=lambda: callback( + self.listen_next_action( + xpaths=xpaths, no_timeout=no_timeout, prevent_action=prevent_action + ) + ) + ) + thread.start() + + def listen( + self, + callback: Callable[[Any], Any], + xpaths: Optional[List[str]] = None, + no_timeout=False, + prevent_action=False, + ) -> Callable: + """ + Listen for user actions and execute the provided callback until the listener is stopped. + Returns a destructor function that can be used to stop listening for events. + """ + active = True + + def destructor(): + nonlocal active + active = False + + def loop(): + while active: + try: + next = self.listen_next_action( + xpaths=xpaths, + no_timeout=no_timeout, + prevent_action=prevent_action, + ) + if active: + callback(next) + except TimeoutException: + continue + except Exception as e: + if active: + raise e + self._destructors.remove(destructor) + + thread = threading.Thread(target=loop) + thread.start() + + self._destructors.append(destructor) + return destructor + + def stop(self): + for destructor in self._destructors: + destructor() + + def __enter__(self): + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + self.stop() class EventListener(BaseEventListener): @@ -9,3 +112,35 @@ def __init__(self, driver: WebDriver): script, prevent_action, xpaths ) ) + + +JS_LISTEN_ACTION = """ +function getElementXPath(element) { + if (element === document.body) return '/html/body'; + if (!element.parentNode) return ''; + for (let i = 0, ix = 0; i < element.parentNode.childNodes.length; i++) { + const sibling = element.parentNode.childNodes[i]; + if (sibling === element) { + const tagName = element.tagName.toLowerCase(); + const position = ix ? `[${ix + 1}]` : ''; + const parentPath = getElementXPath(element.parentNode); + return `${parentPath}/${tagName}${position}`; + } + if (sibling.nodeType === 1 && sibling.tagName === element.tagName) ix++; + } +} +const preventAction = arguments?.[0]; +const listenFor = Array.isArray(arguments?.[1]) ? arguments[1] : null; +const callback = arguments[arguments.length - 1]; +function handleEvent(event) { + const xpath = getElementXPath(event.target); + if (listenFor && !listenFor.includes(xpath)) return true; + if (preventAction) { + event.preventDefault(); + event.stopPropagation(); + } + callback({eventType: event.type, key: event.key, button: event.button, xpath, element: event.target}); + return false; +} +document.addEventListener('click', handleEvent, {capture: true, once: true}); +""" diff --git a/lavague-sdk/lavague/sdk/base_driver.py b/lavague-sdk/lavague/sdk/base_driver.py index f7f76575..f8073309 100644 --- a/lavague-sdk/lavague/sdk/base_driver.py +++ b/lavague-sdk/lavague/sdk/base_driver.py @@ -4,10 +4,6 @@ import re from typing import Any, Callable, Optional, Mapping, Dict, Set, List, Tuple, Union from abc import ABC, abstractmethod -from lavague.core.utilities.format_utils import ( - extract_code_from_funct, - extract_imports_from_lines, -) from enum import Enum from datetime import datetime import hashlib @@ -36,10 +32,6 @@ def __init__(self, url: Optional[str], init_function: Optional[Callable[[], Any] # Flag to check if the page has been previously scanned to avoid erasing screenshots from previous scan self.previously_scanned = False - # extract import lines for later exec of generated code - init_lines = extract_code_from_funct(self.init_function) - self.import_lines = extract_imports_from_lines(init_lines) - if url is not None: self.get(url) @@ -52,11 +44,6 @@ def default_init_code(self) -> Any: """Init the driver, with the imports, since it will be pasted to the beginning of the output""" pass - @abstractmethod - def code_for_init(self) -> str: - """Extract the code to past to the begining of the final script from the init code""" - pass - @abstractmethod def destroy(self) -> None: """Cleanly destroy the underlying driver""" diff --git a/lavague-sdk/lavague/sdk/exceptions.py b/lavague-sdk/lavague/sdk/exceptions.py new file mode 100644 index 00000000..db902474 --- /dev/null +++ b/lavague-sdk/lavague/sdk/exceptions.py @@ -0,0 +1,38 @@ +from typing import Optional + + +class NavigationException(Exception): + pass + + +class ExtractorException(Exception): + pass + + +class CannotBackException(NavigationException): + def __init__(self, message="History root reached, cannot go back"): + super().__init__(message) + + +class RetrievalException(NavigationException): + pass + + +class NoElementException(RetrievalException): + def __init__(self, message="No element found"): + super().__init__(message) + + +class AmbiguousException(RetrievalException): + def __init__(self, message="Multiple elements could match"): + super().__init__(message) + + +class HallucinatedException(RetrievalException): + def __init__(self, xpath: str, message: Optional[str] = None): + super().__init__(message or f"Element was hallucinated: {xpath}") + + +class ElementOutOfContextException(RetrievalException): + def __init__(self, xpath: str, message: Optional[str] = None): + super().__init__(message or f"Element exists but was not in context: {xpath}")