diff --git a/lavague-sdk/lavague/sdk/client.py b/lavague-sdk/lavague/sdk/client.py index f765ee06..0a8d34ac 100644 --- a/lavague-sdk/lavague/sdk/client.py +++ b/lavague-sdk/lavague/sdk/client.py @@ -2,10 +2,10 @@ from typing import Any, Optional, Tuple import requests -from lavague.sdk.action import DEFAULT_PARSER, ActionParser, Instruction, Action +from lavague.sdk.action import DEFAULT_PARSER, ActionParser, Instruction from lavague.sdk.trajectory import Trajectory from lavague.sdk.trajectory.controller import TrajectoryController -from lavague.sdk.trajectory.model import StepCompletion +from lavague.sdk.trajectory.model import StepCompletion, StepKnowledge from lavague.sdk.utilities.config import LAVAGUE_API_BASE_URL, get_config, is_flag_true from PIL import Image, ImageFile from pydantic import BaseModel @@ -103,7 +103,7 @@ def generate_action(self, run_id: str, instruction: Instruction) -> StepCompleti ) return StepCompletion.from_data(content) - def execute_action(self, run_id: str, action: StepCompletion) -> StepCompletion: + def execute_action(self, run_id: str, action: StepKnowledge) -> StepCompletion: content = self.request_api( f"/runs/{run_id}/step/execution", "POST", diff --git a/lavague-sdk/lavague/sdk/trajectory/model.py b/lavague-sdk/lavague/sdk/trajectory/model.py index caa64af5..35cd39ef 100644 --- a/lavague-sdk/lavague/sdk/trajectory/model.py +++ b/lavague-sdk/lavague/sdk/trajectory/model.py @@ -1,7 +1,7 @@ from enum import Enum from typing import Any, Dict, List, Tuple, Optional from pydantic import BaseModel, SerializeAsAny -from lavague.sdk.action import Action, ActionParser +from lavague.sdk.action import Action, ActionParser, Instruction from lavague.sdk.action.base import DEFAULT_PARSER from pydantic import model_validator from pydantic_core import from_json @@ -61,10 +61,8 @@ def write_to_file(self, file_path: str): file.write(json_model) -class StepCompletion(BaseModel): - run_status: RunStatus +class ActionWrapper(BaseModel): action: Optional[Action] - run_mode: RunMode @classmethod def from_data( @@ -84,3 +82,25 @@ def from_dict( action = data.get("action") action = parser.parse(action) if action else None return cls.model_validate({**data, "action": action}) + + @model_validator(mode="before") + @classmethod + def deserialize_action(cls, values: Dict[str, Any]) -> Dict[str, Any]: + if "action" in values: + action_data = values["action"] + if not isinstance(action_data, Action) and "action_type" in action_data: + action_class = DEFAULT_PARSER.engine_action_builders.get( + action_data["action_type"], Action + ) + deserialized_action = action_class.parse(action_data) + values["action"] = deserialized_action + return values + + +class StepCompletion(ActionWrapper): + run_status: RunStatus + run_mode: RunMode + + +class StepKnowledge(ActionWrapper): + instruction: Instruction