diff --git a/smarts/core/agent_interface.py b/smarts/core/agent_interface.py index 23fa799fd1..ab536391e1 100644 --- a/smarts/core/agent_interface.py +++ b/smarts/core/agent_interface.py @@ -123,6 +123,9 @@ class Signals: lookahead: float = 100.0 """The distance in meters to look ahead of the vehicle's current position.""" + include_foes: bool = False + """If signals should include lanes that cross the current lane.""" + class AgentType(IntEnum): """Used to select preconfigured agent interfaces.""" diff --git a/smarts/core/sensor.py b/smarts/core/sensor.py index 3e04bec9c7..43a167cc55 100644 --- a/smarts/core/sensor.py +++ b/smarts/core/sensor.py @@ -24,7 +24,7 @@ from collections import deque from dataclasses import dataclass from functools import lru_cache -from typing import List, Optional, Tuple +from typing import List, Optional, Set, Tuple import numpy as np @@ -710,8 +710,9 @@ def teardown(self, **kwargs): class SignalsSensor(Sensor): """Reports state of traffic signals (lights) in the lanes ahead of vehicle.""" - def __init__(self, lookahead: float): + def __init__(self, lookahead: float, include_foes: bool): self._lookahead = lookahead + self._include_foes = include_foes def __eq__(self, __value: object) -> bool: return ( @@ -736,6 +737,7 @@ def __call__( result = [] if not lane: return result + used_features = set() upcoming_signals = [] for feat in lane.features: if not self._is_signal_type(feat): @@ -747,8 +749,18 @@ def __call__( if lane.offset_along_lane(pt) >= lane_pos.s: upcoming_signals.append(feat) break + if self._include_foes: + self._find_foe_signals(lane, used_features, upcoming_signals) + lookahead = self._lookahead - lane.length + lane_pos.s - self._find_signals_ahead(lane, lookahead, plan.route, upcoming_signals) + self._find_signals_ahead( + lane, + lookahead, + plan.route, + self._include_foes, + used_features, + upcoming_signals, + ) for signal in upcoming_signals: for actor_state in provider_state.actors: @@ -777,11 +789,38 @@ def __call__( return result + def _find_foe_signals( + self, + lane: RoadMap.Lane, + used_features: Set[str], + upcoming_signals: List[RoadMap.Feature], + ): + foes = lane.foes + for foe in foes: + # Features of lanes leading into the foe lanes + foes_incoming_lanes_features = [ + foe_incoming_feature + for incoming_lane in foe.incoming_lanes + for foe_incoming_feature in incoming_lane.features + ] + for feat in [ + *foe.features, + *foes_incoming_lanes_features, + ]: + if not self._is_signal_type(feat): + continue + if feat.feature_id in used_features: + continue + used_features.add(feat.feature_id) + upcoming_signals.append(feat) + def _find_signals_ahead( self, lane: RoadMap.Lane, lookahead: float, route: Optional[RoadMap.Route], + include_foes: bool, + used_features: Set[str], upcoming_signals: List[RoadMap.Feature], ): if lookahead <= 0: @@ -789,11 +828,21 @@ def _find_signals_ahead( for ogl in lane.outgoing_lanes: if route and route.road_length > 0 and ogl.road not in route.roads: continue - upcoming_signals += [ - feat for feat in ogl.features if self._is_signal_type(feat) + new_signals = [ + feat + for feat in ogl.features + if self._is_signal_type(feat) and feat.feature_id not in used_features ] + upcoming_signals += new_signals + used_features.update(f.feature_id for f in new_signals) + self._find_foe_signals(ogl, used_features, upcoming_signals) self._find_signals_ahead( - ogl, lookahead - lane.length, route, upcoming_signals + ogl, + lookahead - lane.length, + route, + include_foes, + used_features, + upcoming_signals, ) def teardown(self, **kwargs): diff --git a/smarts/core/vehicle.py b/smarts/core/vehicle.py index 890f04014c..aa1a4dbc09 100644 --- a/smarts/core/vehicle.py +++ b/smarts/core/vehicle.py @@ -457,7 +457,8 @@ def attach_sensors_to_vehicle( if agent_interface.signals: lookahead = agent_interface.signals.lookahead - sensor = SignalsSensor(lookahead=lookahead) + include_foes = agent_interface.signals.include_foes + sensor = SignalsSensor(lookahead=lookahead, include_foes=include_foes) vehicle.attach_signals_sensor(sensor) for sensor_name, sensor in vehicle.sensors.items():