Skip to content

Commit

Permalink
feat: delegate trajectory export to exporters
Browse files Browse the repository at this point in the history
  • Loading branch information
adeprez committed Sep 19, 2024
1 parent 52bcd8e commit b2cfdb0
Show file tree
Hide file tree
Showing 9 changed files with 28 additions and 163 deletions.
3 changes: 1 addition & 2 deletions lavague-core/lavague/action/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,8 @@
ActionParser,
DEFAULT_PARSER,
UnhandledTypeException,
ActionTranslator,
)

from lavague.action.navigation import NavigationAction

DEFAULT_PARSER.register("navigation", NavigationAction)
DEFAULT_PARSER.register("web_navigation", NavigationAction)
18 changes: 5 additions & 13 deletions lavague-core/lavague/action/base.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Dict, Type, Optional, Callable, TypeVar, Self
from typing import Dict, Type
from pydantic import BaseModel, validate_call
from enum import Enum

Expand All @@ -11,18 +11,16 @@ class ActionStatus(Enum):
class Action(BaseModel):
"""Action performed by the agent."""

engine: str
step_id: str
action_type: str
action: str
url: str
status: ActionStatus

@classmethod
def parse(cls, action_dict: Dict) -> "Action":
return cls(**action_dict)

@classmethod
def add_translator(cls, name: str, translator: "ActionTranslator[Self]"):
setattr(cls, name, translator)


class ActionParser(BaseModel):
engine_action_builders: Dict[str, Type[Action]]
Expand All @@ -39,7 +37,7 @@ def unregister(self, engine: str):
del self.engine_action_builders[engine]

def parse(self, action_dict: Dict) -> Action:
engine = action_dict.get("engine", "")
engine = action_dict.get("action_type", "")
target_type: Type[Action] = self.engine_action_builders.get(engine, Action)
try:
return target_type.parse(action_dict)
Expand All @@ -51,10 +49,4 @@ class UnhandledTypeException(Exception):
pass


T = TypeVar("T", bound=Action)


ActionTranslator = Callable[[T], Optional[str]]


DEFAULT_PARSER = ActionParser()
76 changes: 1 addition & 75 deletions lavague-core/lavague/action/navigation.py
Original file line number Diff line number Diff line change
@@ -1,83 +1,9 @@
from lavague.action import Action
from typing import ClassVar, Dict, Type, Optional, TypeVar

T = TypeVar("T", bound="NavigationAction")
from typing import Optional


class NavigationAction(Action):
"""Navigation action performed by the agent."""

subtypes: ClassVar[Dict[str, Type["NavigationAction"]]] = {}

xpath: str
value: Optional[str] = None

@classmethod
def parse(cls, action_dict: Dict) -> "NavigationAction":
action_name = action_dict.get("action", "")
target_type = cls.subtypes.get(action_name, NavigationAction)
return target_type(**action_dict)

@classmethod
def register_subtype(cls, subtype: str, action: Type[T]):
cls.subtypes[subtype] = action
return cls


def register_navigation(name: str):
def wrapper(cls: Type[T]) -> Type[T]:
NavigationAction.register_subtype(name, cls)
return cls

return wrapper


class NavigationWithValueAction(NavigationAction):
"""Navigation action performed by the agent with a value."""

value: str


@register_navigation("click")
class ClickAction(NavigationAction):
pass


@register_navigation("hover")
class HoverAction(NavigationAction):
pass


@register_navigation("setValue")
class SetValueAction(NavigationWithValueAction):
pass


@register_navigation("setValueAndEnter")
class SetValueAndEnterAction(SetValueAction):
pass


@register_navigation("dropdownSelect")
class DropdownSelectAction(NavigationWithValueAction):
pass


@register_navigation("scroll_down")
class ScrollDownAction(NavigationAction):
pass


@register_navigation("scroll_up")
class ScrollUpAction(NavigationAction):
pass


@register_navigation("back")
class BackAction(NavigationAction):
pass


@register_navigation("switch_tab")
class SwitchTabAction(NavigationAction):
pass
2 changes: 1 addition & 1 deletion lavague-core/lavague/driver/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
from lavague.driver.base import BaseDriver
from lavague.driver.base import BaseDriver
2 changes: 1 addition & 1 deletion lavague-core/lavague/driver/javascript.py
Original file line number Diff line number Diff line change
Expand Up @@ -241,4 +241,4 @@
}
traverse(document.body, '/html/body');
return results;
"""
"""
9 changes: 1 addition & 8 deletions lavague-core/lavague/exporter/__init__.py
Original file line number Diff line number Diff line change
@@ -1,8 +1 @@
from lavague.exporter.base import (
TrajectoryExporter,
ActionTranslator,
ActionWrapper,
method_action_translator,
wrap_action_translator,
translate_action,
)
from lavague.exporter.base import TrajectoryExporter
69 changes: 9 additions & 60 deletions lavague-core/lavague/exporter/base.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,6 @@
from lavague.trajectory import Trajectory
from lavague.action import Action, ActionTranslator
from typing import Optional, Self, Protocol, TypeVar, Callable
import copy


class ActionWrapper(Protocol):
def __call__(self, action: Action, code: str) -> str: ...
from lavague.action import Action
from typing import Optional, Callable, Self


class TrajectoryExporter:
Expand All @@ -20,12 +15,12 @@ def generate_teardown(self, trajectory: Trajectory) -> Optional[str]:
def on_missing_action(self, action: Action, method_name: str) -> None:
"""Generate code for missing action"""
raise NotImplementedError(
f"Action {action.action} translator is missing, please add a '{method_name}' method in {self.__class__.__name__} or attach it with {self.__class__.__name__}.add_action_translator('{action.engine}_{action.action}', my_translator_function)"
f"Action {action.action} translator is missing, please add a '{method_name}' method in {self.__class__.__name__} or attach it with {self.__class__.__name__}.add_action_translator('{action.action_type}', '{action.action}', my_translator_function)"
)

def translate_action(self, action: Action) -> Optional[str]:
"""Translate a single action to target framework code"""
method_name = f"translate_{action.engine}_{action.action}"
method_name = f"translate_{action.action_type}_{action.action}"

if hasattr(self, method_name):
return getattr(self, method_name)(action)
Expand All @@ -42,63 +37,17 @@ def export(self, trajectory: Trajectory) -> str:
actions = [self.translate_action(action) for action in trajectory.actions]
return self.merge_code(setup, *actions, teardown)

def __call__(self, trajectory: Trajectory) -> str:
return self.export(trajectory)

def export_to_file(self, trajectory: Trajectory, file_path: str):
exported = self.export(trajectory)
with open(file_path, "w", encoding="utf-8") as file:
file.write(exported)

def with_wrapper(self, wrapper: ActionWrapper, clone=True) -> Self:
instance = copy.copy(self) if clone else self
instance.translate_action = lambda action: wrap_action_translator(
self.translate_action, wrapper
)(action)
return instance

@classmethod
def add_action_translator(
cls, name: str, translator: Callable[[Self, Action], Optional[str]]
cls,
action_type: str,
action: str,
translator: Callable[[Self, Action], Optional[str]],
) -> None:
"""Add a new action translator to the exporter"""
setattr(cls, f"translate_{name}", translator)

@classmethod
def from_translator(
cls, action_translator: ActionTranslator
) -> "TrajectoryExporter":
class DynamicExporter(cls):
def translate_action(self, action: Action) -> Optional[str]:
return action_translator(action)

return DynamicExporter()

@classmethod
def from_method(cls, method_name: str) -> "TrajectoryExporter":
return cls.from_translator(method_action_translator(method_name))


def translate_action(action: Action, method_name: str) -> Optional[str]:
return getattr(action, method_name)() if hasattr(action, method_name) else None


def method_action_translator(name: str) -> ActionTranslator[Action]:
def wrapper(action: Action) -> Optional[str]:
return translate_action(action, name)

return wrapper


T = TypeVar("T", bound=Action)


def wrap_action_translator(
action_translator: ActionTranslator[T],
wrapper: ActionWrapper,
) -> ActionTranslator[T]:
def wrapped(action: T) -> Optional[str]:
code = action_translator(action)
return wrapper(action, code) if code else None

return wrapped
setattr(cls, f"translate_{action_type}_{action}", translator)
2 changes: 1 addition & 1 deletion lavague-core/lavague/trajectory/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
from lavague.trajectory.base import Trajectory, TrajectoryStatus
from lavague.trajectory.base import Trajectory, TrajectoryStatus
10 changes: 8 additions & 2 deletions lavague-core/lavague/trajectory/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,8 @@ class TrajectoryStatus(Enum):
class Trajectory(BaseModel):
"""Observable trajectory of web interactions towards an objective."""

url: str
run_id: str
start_url: str
objective: str
status: TrajectoryStatus
output: Optional[str]
Expand All @@ -33,10 +34,15 @@ def from_data(
obj["actions"] = [parser.parse(action) for action in obj.get("actions", [])]
return cls.model_validate(obj)

@classmethod
def from_dict(cls, data: dict, parser: ActionParser = DEFAULT_PARSER):
data["actions"] = [parser.parse(action) for action in data.get("actions", [])]
return cls.model_validate(data)

@classmethod
def from_file(
cls, file_path: str, parser: ActionParser = DEFAULT_PARSER, encoding="utf-8"
):
with open(file_path, "r", encoding=encoding) as file:
content = file.read()
return cls.from_data(content, parser)
return cls.from_data(content, parser)

0 comments on commit b2cfdb0

Please sign in to comment.