Skip to content

Commit

Permalink
Add type checking improvements.
Browse files Browse the repository at this point in the history
  • Loading branch information
Gamenot committed Jan 2, 2024
1 parent 87da356 commit a420fd1
Show file tree
Hide file tree
Showing 2 changed files with 36 additions and 23 deletions.
10 changes: 6 additions & 4 deletions smarts/core/actor_capture_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,18 +19,20 @@
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
# THE SOFTWARE.

from __future__ import annotations

import warnings
from collections import namedtuple
from dataclasses import replace
from typing import Any, Dict, Optional
from typing import TYPE_CHECKING, Any, Dict, Optional

from smarts.core.actor import ActorState
from smarts.core.plan import Plan
from smarts.core.vehicle import Vehicle
from smarts.sstudio.types import ConditionRequires

if TYPE_CHECKING:
from smarts.core.actor import ActorState
from smarts.core.vehicle import Vehicle


class ActorCaptureManager:
"""The base for managers that handle transition of control of actors."""
Expand Down
49 changes: 30 additions & 19 deletions smarts/core/vehicle_index.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,12 +23,15 @@
shadower
shadowers
"""
from __future__ import annotations

import logging
from copy import copy, deepcopy
from functools import lru_cache
from io import StringIO
from pathlib import Path
from typing import (
TYPE_CHECKING,
Dict,
FrozenSet,
Iterator,
Expand All @@ -53,7 +56,15 @@
from .controllers import ControllerState
from .road_map import RoadMap
from .sensors import SensorState
from .vehicle import Vehicle, VehicleState
from .vehicle import Vehicle

if TYPE_CHECKING:
from smarts.core.agent_interface import AgentInterface
from smarts.core.plan import Plan
from smarts.core.smarts import SMARTS
from smarts.core.renderer_base import RendererBase

from .vehicle import VehicleState

VEHICLE_INDEX_ID_LENGTH = 128

Expand Down Expand Up @@ -375,11 +386,11 @@ def teardown(self, renderer):
@clear_cache
def start_agent_observation(
self,
sim,
sim: SMARTS,
vehicle_id,
agent_id,
agent_interface,
plan,
agent_interface: AgentInterface,
plan: Plan,
boid=False,
initialize_sensors=True,
):
Expand Down Expand Up @@ -422,13 +433,13 @@ def start_agent_observation(
@clear_cache
def switch_control_to_agent(
self,
sim,
sim: SMARTS,
vehicle_id,
agent_id,
boid=False,
hijacking=False,
recreate=False,
agent_interface=None,
agent_interface: Optional[AgentInterface]=None,
):
"""Give control of the specified vehicle to the specified agent.
Args:
Expand Down Expand Up @@ -540,7 +551,7 @@ def stop_agent_observation(self, vehicle_id) -> Vehicle:

@clear_cache
def relinquish_agent_control(
self, sim, vehicle_id: str, road_map
self, sim: SMARTS, vehicle_id: str, road_map: RoadMap
) -> Tuple[VehicleState, List[str]]:
"""Give control of the vehicle back to its original controller."""
self._log.debug(f"Relinquishing agent control v_id={vehicle_id}")
Expand Down Expand Up @@ -580,7 +591,7 @@ def relinquish_agent_control(
return vehicle.state, route

@clear_cache
def attach_sensors_to_vehicle(self, sim, vehicle_id, agent_interface, plan):
def attach_sensors_to_vehicle(self, sim: SMARTS, vehicle_id, agent_interface: AgentInterface, plan: Plan):
"""Attach sensors as per the agent interface requirements to the specified vehicle."""
vehicle_id = _2id(vehicle_id)

Expand All @@ -603,7 +614,7 @@ def attach_sensors_to_vehicle(self, sim, vehicle_id, agent_interface, plan):
)

def _switch_control_to_agent_recreate(
self, sim, vehicle_id, agent_id, boid, hijacking
self, sim: SMARTS, vehicle_id, agent_id, boid: bool, hijacking: bool
):
# XXX: vehicle_id and agent_id are already fixed-length as this is an internal
# method.
Expand Down Expand Up @@ -677,10 +688,10 @@ def _switch_control_to_agent_recreate(

def build_agent_vehicle(
self,
sim,
sim: SMARTS,
agent_id,
agent_interface,
plan,
agent_interface: AgentInterface,
plan: Plan,
trainable: bool,
initial_speed: Optional[float] = None,
boid: bool = False,
Expand Down Expand Up @@ -724,12 +735,12 @@ def build_agent_vehicle(
@clear_cache
def _enfranchise_agent(
self,
sim,
sim: SMARTS,
agent_id,
agent_interface,
vehicle,
controller_state,
sensor_state,
agent_interface: AgentInterface,
vehicle: Vehicle,
controller_state: ControllerState,
sensor_state: SensorState,
boid: bool = False,
hijacking: bool = False,
):
Expand Down Expand Up @@ -766,7 +777,7 @@ def _enfranchise_agent(

@clear_cache
def build_social_vehicle(
self, sim, vehicle_state, owner_id, vehicle_id=None
self, sim: SMARTS, vehicle_state: VehicleState, owner_id, vehicle_id=None
) -> Vehicle:
"""Build an entirely new vehicle for a social agent."""
if vehicle_id is None:
Expand Down Expand Up @@ -804,7 +815,7 @@ def build_social_vehicle(

return vehicle

def begin_rendering_vehicles(self, renderer):
def begin_rendering_vehicles(self, renderer: RendererBase):
"""Render vehicles using the specified renderer."""
agent_vehicle_ids = self.agent_vehicle_ids()
for vehicle in self._vehicles.values():
Expand Down

0 comments on commit a420fd1

Please sign in to comment.