Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add endpoints to compute and execute custom actions #615

Merged
merged 6 commits into from
Oct 14, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,8 @@
NoPageException,
)

from lavague.sdk.action.navigation import WebNavigationAction, NavigationCommand
from lavague.sdk.action import ActionStatus
from selenium.common.exceptions import (
NoSuchElementException,
TimeoutException,
Expand Down
2 changes: 2 additions & 0 deletions lavague-sdk/lavague/sdk/action/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
from lavague.sdk.action.base import (
Action,
Instruction,
EngineType,
ActionType,
ActionStatus,
ActionParser,
Expand Down
13 changes: 13 additions & 0 deletions lavague-sdk/lavague/sdk/action/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,13 @@ class ActionType(str, Enum):
EXTRACTION = "web_extraction"


class EngineType(str, Enum):
NAVIGATION = "Navigation Engine"
EXTRACTION = "Element Extraction Engine"
CONTROLS = "Navigation Controls"
COMPLETE = "COMPLETE"


T = TypeVar("T")


Expand Down Expand Up @@ -55,6 +62,12 @@ def parse(self, action_dict: Dict) -> Action:
return Action.parse(action_dict)


class Instruction(BaseModel):
chain_of_toughts: str
engine: EngineType
engine_instruction: str


class UnhandledTypeException(Exception):
pass

Expand Down
1 change: 1 addition & 0 deletions lavague-sdk/lavague/sdk/base_driver/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,7 @@ def execute(self, action: NavigationOutput) -> None:
raise NotImplementedError(
f"Action {action.navigation_command} not implemented"
)
self.wait_for_idle()

@abstractmethod
def destroy(self) -> None:
Expand Down
29 changes: 26 additions & 3 deletions lavague-sdk/lavague/sdk/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,10 @@
from typing import Any, Optional, Tuple

import requests
from lavague.sdk.action import DEFAULT_PARSER, ActionParser
from lavague.sdk.action import DEFAULT_PARSER, ActionParser, Instruction
from lavague.sdk.trajectory import Trajectory
from lavague.sdk.trajectory.controller import TrajectoryController
from lavague.sdk.trajectory.model import StepCompletion
from lavague.sdk.trajectory.model import StepCompletion, StepKnowledge
from lavague.sdk.utilities.config import LAVAGUE_API_BASE_URL, get_config, is_flag_true
from PIL import Image, ImageFile
from pydantic import BaseModel
Expand Down Expand Up @@ -86,7 +86,30 @@ def next_step(self, run_id: str) -> StepCompletion:
f"/runs/{run_id}/step",
"POST",
)
return StepCompletion.model_validate_json(content)
return StepCompletion.from_data(content)

def generate_instruction(self, run_id: str) -> Instruction:
content = self.request_api(
f"/runs/{run_id}/step/instruction",
"POST",
)
return Instruction.model_validate_json(content)

def generate_action(self, run_id: str, instruction: Instruction) -> StepCompletion:
content = self.request_api(
f"/runs/{run_id}/step/action",
"POST",
instruction.model_dump(),
)
return StepCompletion.from_data(content)

def execute_action(self, run_id: str, action: StepKnowledge) -> StepCompletion:
content = self.request_api(
f"/runs/{run_id}/step/execution",
"POST",
action.model_dump(),
)
return StepCompletion.from_data(content)

def stop(self, run_id: str) -> None:
self.request_api(
Expand Down
53 changes: 50 additions & 3 deletions lavague-sdk/lavague/sdk/trajectory/model.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
from enum import Enum
from typing import Any, Dict, List, Tuple, Optional
from pydantic import BaseModel, SerializeAsAny
from lavague.sdk.action import Action
from lavague.sdk.action import Action, ActionParser, Instruction
from lavague.sdk.action.base import DEFAULT_PARSER
from pydantic import model_validator
from pydantic_core import from_json


class RunStatus(str, Enum):
Expand Down Expand Up @@ -60,7 +61,53 @@ def write_to_file(self, file_path: str):
file.write(json_model)


class StepCompletion(BaseModel):
run_status: RunStatus
class ActionWrapper(BaseModel):
action: Optional[Action]

@classmethod
def from_data(
cls,
data: str | bytes | bytearray,
parser: ActionParser = DEFAULT_PARSER,
):
obj = from_json(data)
return cls.from_dict(obj, parser)

@classmethod
def from_dict(
cls,
data: Dict,
parser: ActionParser = DEFAULT_PARSER,
):
action = data.get("action")
action = parser.parse(action) if action else None
return cls.model_validate({**data, "action": action})

@model_validator(mode="before")
@classmethod
def deserialize_action(cls, values: Dict[str, Any]) -> Dict[str, Any]:
if "action" in values:
action_data = values["action"]
if (
action_data
and not isinstance(action_data, Action)
and "action_type" in action_data
):
action_class = DEFAULT_PARSER.engine_action_builders.get(
action_data["action_type"], Action
)
deserialized_action = action_class.parse(action_data)
values["action"] = deserialized_action
return values


class StepKnowledge(ActionWrapper):
instruction: Instruction


class StepCompletion(ActionWrapper):
run_status: RunStatus
run_mode: RunMode

def to_knowledge(self, instruction: Instruction) -> StepKnowledge:
return StepKnowledge(instruction=instruction, action=self.action)
Loading