From 1858cb6ebc991658392b89435e8cd217526895a6 Mon Sep 17 00:00:00 2001 From: Montgomery Alban Date: Wed, 10 Jan 2024 22:03:18 +0000 Subject: [PATCH] Add types information to sensor. --- smarts/core/sensor.py | 37 +++++++++++++++++++------------------ 1 file changed, 19 insertions(+), 18 deletions(-) diff --git a/smarts/core/sensor.py b/smarts/core/sensor.py index 3ac18cfd4c..a9a6ddcf30 100644 --- a/smarts/core/sensor.py +++ b/smarts/core/sensor.py @@ -29,7 +29,7 @@ from collections import deque from dataclasses import dataclass from functools import lru_cache -from typing import TYPE_CHECKING, Any, List, Optional, Tuple, Union +from typing import TYPE_CHECKING, Any, Collection, List, Optional, Tuple, Union import numpy as np @@ -70,6 +70,7 @@ from smarts.core.actor import ActorState from smarts.core.lidar_sensor_params import SensorParams from smarts.core.plan import Plan + from smarts.core.provider import ProviderState from smarts.core.simulation_frame import SimulationFrame from smarts.core.vehicle_state import VehicleState @@ -85,7 +86,7 @@ def _gen_base_sensor_name(base_name: str, actor_id: str): class Sensor(metaclass=abc.ABCMeta): """The sensor base class.""" - def step(self, sim_frame, **kwargs): + def step(self, sim_frame: SimulationFrame, **kwargs): """Update sensor state.""" @abc.abstractmethod @@ -657,7 +658,7 @@ def __eq__(self, __value: object) -> bool: ) def update_distance_wps_record( - self, waypoint_paths, vehicle, plan: Plan, road_map: RoadMap + self, waypoint_paths: List[List[Waypoint]], vehicle_state: VehicleState, plan: Plan, road_map: RoadMap ): """Append a waypoint to the history if it is not already counted.""" # Distance calculation. Intention is the shortest trip travelled at the lane @@ -677,7 +678,7 @@ def update_distance_wps_record( ) if not self._wps_for_distance: - self._last_actor_position = vehicle.pose.position + self._last_actor_position = vehicle_state.pose.position if should_count_wp: self._wps_for_distance.append(new_wp) return # sensor does not have enough history @@ -693,11 +694,11 @@ def update_distance_wps_record( additional_distance = TripMeterSensor._compute_additional_dist_travelled( most_recent_wp, new_wp, - vehicle.pose.position, + vehicle_state.pose.position, self._last_actor_position, ) self._dist_travelled += additional_distance - self._last_actor_position = vehicle.pose.position + self._last_actor_position = vehicle_state.pose.position @staticmethod def _compute_additional_dist_travelled( @@ -716,7 +717,7 @@ def _compute_additional_dist_travelled( distance = np.dot(position_disp_vec, wp_unit_vec) return distance - def __call__(self, increment=False): + def __call__(self, increment: bool=False): if increment: return self._dist_travelled - self._last_dist_travelled @@ -729,15 +730,15 @@ def teardown(self, **kwargs): class NeighborhoodVehiclesSensor(Sensor): """Detects other vehicles around the sensor equipped vehicle.""" - def __init__(self, radius=None): + def __init__(self, radius: Optional[float]=None): self._radius = radius @property - def radius(self): + def radius(self) -> float: """Radius to check for nearby vehicles.""" return self._radius - def __call__(self, vehicle_state: VehicleState, vehicle_states): + def __call__(self, vehicle_state: VehicleState, vehicle_states: Collection[VehicleState]) -> List[VehicleState]: return neighborhood_vehicles_around_vehicle( vehicle_state, vehicle_states, radius=self._radius ) @@ -759,10 +760,10 @@ def mutable(self) -> bool: class WaypointsSensor(Sensor): """Detects waypoints leading forward along the vehicle plan.""" - def __init__(self, lookahead=32): + def __init__(self, lookahead: int=32): self._lookahead = lookahead - def __call__(self, vehicle_state: VehicleState, plan: Plan, road_map): + def __call__(self, vehicle_state: VehicleState, plan: Plan, road_map: RoadMap): return road_map.waypoint_paths( pose=vehicle_state.pose, lookahead=self._lookahead, @@ -786,10 +787,10 @@ def mutable(self) -> bool: class RoadWaypointsSensor(Sensor): """Detects waypoints from all paths nearby the vehicle.""" - def __init__(self, horizon=32): + def __init__(self, horizon: int=32): self._horizon = horizon - def __call__(self, vehicle_state: VehicleState, plan, road_map) -> RoadWaypoints: + def __call__(self, vehicle_state: VehicleState, plan: Plan, road_map: RoadMap) -> RoadWaypoints: veh_pt = vehicle_state.pose.point lane = road_map.nearest_lane(veh_pt) if not lane: @@ -807,7 +808,7 @@ def __call__(self, vehicle_state: VehicleState, plan, road_map) -> RoadWaypoints return RoadWaypoints(lanes=lane_paths) def _paths_for_lane( - self, lane, vehicle_state: VehicleState, plan, overflow_offset=None + self, lane: RoadMap.Lane, vehicle_state: VehicleState, plan: Plan, overflow_offset: Optional[float]=None ): """Gets waypoint paths along the given lane.""" # XXX: the following assumes waypoint spacing is 1m @@ -907,7 +908,7 @@ class LanePositionSensor(Sensor): def __init__(self): pass - def __call__(self, lane: RoadMap.Lane, vehicle_state): + def __call__(self, lane: RoadMap.Lane, vehicle_state: VehicleState): return lane.to_lane_coord(vehicle_state.pose.point) def __eq__(self, __value: object) -> bool: @@ -936,7 +937,7 @@ def __eq__(self, __value: object) -> bool: and self._speed_accuracy == __value._speed_accuracy ) - def __call__(self, vehicle_state: VehicleState, plan, road_map): + def __call__(self, vehicle_state: VehicleState, plan: Plan, road_map: RoadMap): near_points: List[ViaPoint] = list() hit_points: List[ViaPoint] = list() if plan.mission is None: @@ -1014,7 +1015,7 @@ def __call__( lane_pos: RefLinePoint, state: VehicleState, plan: Plan, - provider_state, # ProviderState + provider_state: ProviderState, ) -> List[SignalObservation]: result = [] if not lane: