Skip to content

Commit

Permalink
Add foe traffic signals to observation signals.
Browse files Browse the repository at this point in the history
  • Loading branch information
Gamenot committed Apr 28, 2023
1 parent 79dbfa0 commit e8224bd
Show file tree
Hide file tree
Showing 3 changed files with 60 additions and 7 deletions.
3 changes: 3 additions & 0 deletions smarts/core/agent_interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,6 +123,9 @@ class Signals:
lookahead: float = 100.0
"""The distance in meters to look ahead of the vehicle's current position."""

include_foes: bool = False
"""If signals should include lanes that cross the current lane."""


class AgentType(IntEnum):
"""Used to select preconfigured agent interfaces."""
Expand Down
61 changes: 55 additions & 6 deletions smarts/core/sensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
from collections import deque
from dataclasses import dataclass
from functools import lru_cache
from typing import List, Optional, Tuple
from typing import List, Optional, Set, Tuple

import numpy as np

Expand Down Expand Up @@ -710,8 +710,9 @@ def teardown(self, **kwargs):
class SignalsSensor(Sensor):
"""Reports state of traffic signals (lights) in the lanes ahead of vehicle."""

def __init__(self, lookahead: float):
def __init__(self, lookahead: float, include_foes: bool):
self._lookahead = lookahead
self._include_foes = include_foes

def __eq__(self, __value: object) -> bool:
return (
Expand All @@ -736,6 +737,7 @@ def __call__(
result = []
if not lane:
return result
used_features = set()
upcoming_signals = []
for feat in lane.features:
if not self._is_signal_type(feat):
Expand All @@ -747,8 +749,18 @@ def __call__(
if lane.offset_along_lane(pt) >= lane_pos.s:
upcoming_signals.append(feat)
break
if self._include_foes:
self._find_foe_signals(lane, used_features, upcoming_signals)

lookahead = self._lookahead - lane.length + lane_pos.s
self._find_signals_ahead(lane, lookahead, plan.route, upcoming_signals)
self._find_signals_ahead(
lane,
lookahead,
plan.route,
self._include_foes,
used_features,
upcoming_signals,
)

for signal in upcoming_signals:
for actor_state in provider_state.actors:
Expand Down Expand Up @@ -777,23 +789,60 @@ def __call__(

return result

def _find_foe_signals(
self,
lane: RoadMap.Lane,
used_features: Set[str],
upcoming_signals: List[RoadMap.Feature],
):
foes = lane.foes
for foe in foes:
# Features of lanes leading into the foe lanes
foes_incoming_lanes_features = [
foe_incoming_feature
for incoming_lane in foe.incoming_lanes
for foe_incoming_feature in incoming_lane.features
]
for feat in [
*foe.features,
*foes_incoming_lanes_features,
]:
if not self._is_signal_type(feat):
continue
if feat.feature_id in used_features:
continue
used_features.add(feat.feature_id)
upcoming_signals.append(feat)

def _find_signals_ahead(
self,
lane: RoadMap.Lane,
lookahead: float,
route: Optional[RoadMap.Route],
include_foes: bool,
used_features: Set[str],
upcoming_signals: List[RoadMap.Feature],
):
if lookahead <= 0:
return
for ogl in lane.outgoing_lanes:
if route and route.road_length > 0 and ogl.road not in route.roads:
continue
upcoming_signals += [
feat for feat in ogl.features if self._is_signal_type(feat)
new_signals = [
feat
for feat in ogl.features
if self._is_signal_type(feat) and feat.feature_id not in used_features
]
upcoming_signals += new_signals
used_features.update(f.feature_id for f in new_signals)
self._find_foe_signals(ogl, used_features, upcoming_signals)
self._find_signals_ahead(
ogl, lookahead - lane.length, route, upcoming_signals
ogl,
lookahead - lane.length,
route,
include_foes,
used_features,
upcoming_signals,
)

def teardown(self, **kwargs):
Expand Down
3 changes: 2 additions & 1 deletion smarts/core/vehicle.py
Original file line number Diff line number Diff line change
Expand Up @@ -457,7 +457,8 @@ def attach_sensors_to_vehicle(

if agent_interface.signals:
lookahead = agent_interface.signals.lookahead
sensor = SignalsSensor(lookahead=lookahead)
include_foes = agent_interface.signals.include_foes
sensor = SignalsSensor(lookahead=lookahead, include_foes=include_foes)
vehicle.attach_signals_sensor(sensor)

for sensor_name, sensor in vehicle.sensors.items():
Expand Down

0 comments on commit e8224bd

Please sign in to comment.