Skip to content

Commit

Permalink
feat: add knowledge
Browse files Browse the repository at this point in the history
  • Loading branch information
adeprez committed Oct 11, 2024
1 parent ca5594f commit a8e0c57
Show file tree
Hide file tree
Showing 2 changed files with 27 additions and 7 deletions.
6 changes: 3 additions & 3 deletions lavague-sdk/lavague/sdk/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,10 @@
from typing import Any, Optional, Tuple

import requests
from lavague.sdk.action import DEFAULT_PARSER, ActionParser, Instruction, Action
from lavague.sdk.action import DEFAULT_PARSER, ActionParser, Instruction
from lavague.sdk.trajectory import Trajectory
from lavague.sdk.trajectory.controller import TrajectoryController
from lavague.sdk.trajectory.model import StepCompletion
from lavague.sdk.trajectory.model import StepCompletion, StepKnowledge
from lavague.sdk.utilities.config import LAVAGUE_API_BASE_URL, get_config, is_flag_true
from PIL import Image, ImageFile
from pydantic import BaseModel
Expand Down Expand Up @@ -103,7 +103,7 @@ def generate_action(self, run_id: str, instruction: Instruction) -> StepCompleti
)
return StepCompletion.from_data(content)

def execute_action(self, run_id: str, action: StepCompletion) -> StepCompletion:
def execute_action(self, run_id: str, action: StepKnowledge) -> StepCompletion:
content = self.request_api(
f"/runs/{run_id}/step/execution",
"POST",
Expand Down
28 changes: 24 additions & 4 deletions lavague-sdk/lavague/sdk/trajectory/model.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from enum import Enum
from typing import Any, Dict, List, Tuple, Optional
from pydantic import BaseModel, SerializeAsAny
from lavague.sdk.action import Action, ActionParser
from lavague.sdk.action import Action, ActionParser, Instruction
from lavague.sdk.action.base import DEFAULT_PARSER
from pydantic import model_validator
from pydantic_core import from_json
Expand Down Expand Up @@ -61,10 +61,8 @@ def write_to_file(self, file_path: str):
file.write(json_model)


class StepCompletion(BaseModel):
run_status: RunStatus
class ActionWrapper(BaseModel):
action: Optional[Action]
run_mode: RunMode

@classmethod
def from_data(
Expand All @@ -84,3 +82,25 @@ def from_dict(
action = data.get("action")
action = parser.parse(action) if action else None
return cls.model_validate({**data, "action": action})

@model_validator(mode="before")
@classmethod
def deserialize_action(cls, values: Dict[str, Any]) -> Dict[str, Any]:
if "action" in values:
action_data = values["action"]
if not isinstance(action_data, Action) and "action_type" in action_data:
action_class = DEFAULT_PARSER.engine_action_builders.get(
action_data["action_type"], Action
)
deserialized_action = action_class.parse(action_data)
values["action"] = deserialized_action
return values


class StepCompletion(ActionWrapper):
run_status: RunStatus
run_mode: RunMode


class StepKnowledge(ActionWrapper):
instruction: Instruction

0 comments on commit a8e0c57

Please sign in to comment.