Skip to content

Commit

Permalink
Add types information to sensor.
Browse files Browse the repository at this point in the history
  • Loading branch information
Gamenot committed Jan 10, 2024
1 parent c7bbbe3 commit 1858cb6
Showing 1 changed file with 19 additions and 18 deletions.
37 changes: 19 additions & 18 deletions smarts/core/sensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@
from collections import deque
from dataclasses import dataclass
from functools import lru_cache
from typing import TYPE_CHECKING, Any, List, Optional, Tuple, Union
from typing import TYPE_CHECKING, Any, Collection, List, Optional, Tuple, Union

import numpy as np

Expand Down Expand Up @@ -70,6 +70,7 @@
from smarts.core.actor import ActorState
from smarts.core.lidar_sensor_params import SensorParams
from smarts.core.plan import Plan
from smarts.core.provider import ProviderState
from smarts.core.simulation_frame import SimulationFrame
from smarts.core.vehicle_state import VehicleState

Expand All @@ -85,7 +86,7 @@ def _gen_base_sensor_name(base_name: str, actor_id: str):
class Sensor(metaclass=abc.ABCMeta):
"""The sensor base class."""

def step(self, sim_frame, **kwargs):
def step(self, sim_frame: SimulationFrame, **kwargs):
"""Update sensor state."""

@abc.abstractmethod
Expand Down Expand Up @@ -657,7 +658,7 @@ def __eq__(self, __value: object) -> bool:
)

def update_distance_wps_record(
self, waypoint_paths, vehicle, plan: Plan, road_map: RoadMap
self, waypoint_paths: List[List[Waypoint]], vehicle_state: VehicleState, plan: Plan, road_map: RoadMap
):
"""Append a waypoint to the history if it is not already counted."""
# Distance calculation. Intention is the shortest trip travelled at the lane
Expand All @@ -677,7 +678,7 @@ def update_distance_wps_record(
)

if not self._wps_for_distance:
self._last_actor_position = vehicle.pose.position
self._last_actor_position = vehicle_state.pose.position
if should_count_wp:
self._wps_for_distance.append(new_wp)
return # sensor does not have enough history
Expand All @@ -693,11 +694,11 @@ def update_distance_wps_record(
additional_distance = TripMeterSensor._compute_additional_dist_travelled(
most_recent_wp,
new_wp,
vehicle.pose.position,
vehicle_state.pose.position,
self._last_actor_position,
)
self._dist_travelled += additional_distance
self._last_actor_position = vehicle.pose.position
self._last_actor_position = vehicle_state.pose.position

@staticmethod
def _compute_additional_dist_travelled(
Expand All @@ -716,7 +717,7 @@ def _compute_additional_dist_travelled(
distance = np.dot(position_disp_vec, wp_unit_vec)
return distance

def __call__(self, increment=False):
def __call__(self, increment: bool=False):
if increment:
return self._dist_travelled - self._last_dist_travelled

Expand All @@ -729,15 +730,15 @@ def teardown(self, **kwargs):
class NeighborhoodVehiclesSensor(Sensor):
"""Detects other vehicles around the sensor equipped vehicle."""

def __init__(self, radius=None):
def __init__(self, radius: Optional[float]=None):
self._radius = radius

@property
def radius(self):
def radius(self) -> float:
"""Radius to check for nearby vehicles."""
return self._radius

def __call__(self, vehicle_state: VehicleState, vehicle_states):
def __call__(self, vehicle_state: VehicleState, vehicle_states: Collection[VehicleState]) -> List[VehicleState]:
return neighborhood_vehicles_around_vehicle(
vehicle_state, vehicle_states, radius=self._radius
)
Expand All @@ -759,10 +760,10 @@ def mutable(self) -> bool:
class WaypointsSensor(Sensor):
"""Detects waypoints leading forward along the vehicle plan."""

def __init__(self, lookahead=32):
def __init__(self, lookahead: int=32):
self._lookahead = lookahead

def __call__(self, vehicle_state: VehicleState, plan: Plan, road_map):
def __call__(self, vehicle_state: VehicleState, plan: Plan, road_map: RoadMap):
return road_map.waypoint_paths(
pose=vehicle_state.pose,
lookahead=self._lookahead,
Expand All @@ -786,10 +787,10 @@ def mutable(self) -> bool:
class RoadWaypointsSensor(Sensor):
"""Detects waypoints from all paths nearby the vehicle."""

def __init__(self, horizon=32):
def __init__(self, horizon: int=32):
self._horizon = horizon

def __call__(self, vehicle_state: VehicleState, plan, road_map) -> RoadWaypoints:
def __call__(self, vehicle_state: VehicleState, plan: Plan, road_map: RoadMap) -> RoadWaypoints:
veh_pt = vehicle_state.pose.point
lane = road_map.nearest_lane(veh_pt)
if not lane:
Expand All @@ -807,7 +808,7 @@ def __call__(self, vehicle_state: VehicleState, plan, road_map) -> RoadWaypoints
return RoadWaypoints(lanes=lane_paths)

def _paths_for_lane(
self, lane, vehicle_state: VehicleState, plan, overflow_offset=None
self, lane: RoadMap.Lane, vehicle_state: VehicleState, plan: Plan, overflow_offset: Optional[float]=None
):
"""Gets waypoint paths along the given lane."""
# XXX: the following assumes waypoint spacing is 1m
Expand Down Expand Up @@ -907,7 +908,7 @@ class LanePositionSensor(Sensor):
def __init__(self):
pass

def __call__(self, lane: RoadMap.Lane, vehicle_state):
def __call__(self, lane: RoadMap.Lane, vehicle_state: VehicleState):
return lane.to_lane_coord(vehicle_state.pose.point)

def __eq__(self, __value: object) -> bool:
Expand Down Expand Up @@ -936,7 +937,7 @@ def __eq__(self, __value: object) -> bool:
and self._speed_accuracy == __value._speed_accuracy
)

def __call__(self, vehicle_state: VehicleState, plan, road_map):
def __call__(self, vehicle_state: VehicleState, plan: Plan, road_map: RoadMap):
near_points: List[ViaPoint] = list()
hit_points: List[ViaPoint] = list()
if plan.mission is None:
Expand Down Expand Up @@ -1014,7 +1015,7 @@ def __call__(
lane_pos: RefLinePoint,
state: VehicleState,
plan: Plan,
provider_state, # ProviderState
provider_state: ProviderState,
) -> List[SignalObservation]:
result = []
if not lane:
Expand Down

0 comments on commit 1858cb6

Please sign in to comment.