Skip to content

Commit

Permalink
Add updating buffers.
Browse files Browse the repository at this point in the history
  • Loading branch information
Gamenot committed Jan 12, 2024
1 parent b6db527 commit 06d92a2
Show file tree
Hide file tree
Showing 16 changed files with 374 additions and 120 deletions.
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,8 @@ Copy and pasting the git commit messages is __NOT__ enough.
- Waypoints now have a `position` property (which will eventually replace `pos`).
- You must now implement `act()` for any agent inheriting from `smarts.core.agent.Agent`.
- `FunctionAgent` is now no longer dynamically defined.
- `Vias.hit_via_points` is now a property.
- `ViaPoint` now has an attribute `hit` which determines if the point has been "collected".
### Deprecated
- Module `smarts.core.models` is now deprecated in favour of `smarts.assets`.
- Deprecated a few things related to vehicles in the `Scenario` class, including the `vehicle_filepath`, `tire_parameters_filepath`, and `controller_parameters_filepath`. The functionality is now handled through the vehicle definitions.
Expand Down
29 changes: 24 additions & 5 deletions examples/occlusion/mask.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,12 +28,15 @@
from smarts.core.agent_interface import (
OGM,
RGB,
Accelerometer,
AgentInterface,
CustomRender,
CustomRenderBufferDependency,
CustomRenderCameraDependency,
CustomRenderConstantDependency,
DoneCriteria,
DrivableAreaGridMap,
NeighborhoodVehicles,
OcclusionMap,
RoadWaypoints,
Waypoints,
Expand Down Expand Up @@ -740,7 +743,7 @@ def act(self, obs: Optional[Observation], **configs):
return self._inner_agent.act(dowgraded_obs, **configs)


def occlusion_main():
def occlusion_main(steps):
from smarts.env.gymnasium.hiway_env_v1 import HiWayEnvV1
from smarts.zoo.registry import make

Expand All @@ -763,9 +766,9 @@ def occlusion_main():
}

with pkg_resources.path(glsl, "map_values.frag") as frag_shader:
agent_interface = replace(
agent_interface: AgentInterface = replace(
agent_spec.interface,
neighborhood_vehicle_states=True,
neighborhood_vehicle_states=NeighborhoodVehicles(),
drivable_area_grid_map=DrivableAreaGridMap(
width=w,
height=h,
Expand All @@ -787,6 +790,8 @@ def occlusion_main():
resolution=resolution,
surface_noise=True,
),
lane_positions=True,
accelerometer=Accelerometer(),
road_waypoints=RoadWaypoints(horizon=50),
waypoint_paths=Waypoints(lookahead=50),
done_criteria=DoneCriteria(
Expand All @@ -805,6 +810,14 @@ def occlusion_main():
camera_dependency_name=CameraSensorName.TOP_DOWN_RGB,
variable_name="iChannel1",
),
CustomRenderBufferDependency(
buffer_dependency_name=BufferName.ELAPSED_SIM_TIME,
variable_name=BufferName.ELAPSED_SIM_TIME.value,
),
CustomRenderBufferDependency(
buffer_dependency_name=BufferName.NEIGHBORHOOD_VEHICLE_STATES_POSITION,
variable_name=BufferName.NEIGHBORHOOD_VEHICLE_STATES_POSITION.value,
),
CustomRenderConstantDependency(
value=(0.1, 0.5, 0.1),
variable_name="empty_color",
Expand Down Expand Up @@ -833,7 +846,7 @@ def occlusion_main():
) as env:
terms = {"__all__": False}
obs, info = env.reset()
for _ in range(70):
for _ in range(steps):
if terms["__all__"]:
break
acts = {a_id: a.act(obs.get(a_id)) for a_id, a in agents.items()}
Expand All @@ -848,6 +861,12 @@ def occlusion_main():
import argparse

parser = argparse.ArgumentParser("Downgrader")
parser.add_argument(
"--steps",
help="The number of steps to take.",
type=int,
default=70,
)
args = parser.parse_args()

occlusion_main()
occlusion_main(args.steps)
9 changes: 7 additions & 2 deletions smarts/core/agent_interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -652,11 +652,16 @@ def replace(self, **kwargs) -> AgentInterface:
Waypoints(...)
"""
return replace(self, **kwargs)

@property
def requires_rendering(self):
"""If this agent interface requires a renderer."""
return bool(self.top_down_rgb or self.occupancy_grid_map or self.drivable_area_grid_map or self.custom_renders)
return bool(
self.top_down_rgb
or self.occupancy_grid_map
or self.drivable_area_grid_map
or self.custom_renders
)

@property
def ogm(self):
Expand Down
42 changes: 24 additions & 18 deletions smarts/core/agent_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,26 +17,30 @@
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
# THE SOFTWARE.
from __future__ import annotations

import logging
import weakref
from concurrent import futures
from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Union
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Set, Tuple, Union

from envision.etypes import format_actor_id
from smarts.core.actor import ActorRole
from smarts.core.agent_interface import AgentInterface
from smarts.core.bubble_manager import BubbleManager
from smarts.core.data_model import SocialAgent
from smarts.core.local_agent_buffer import LocalAgentBuffer
from smarts.core.observations import Observation
from smarts.core.plan import Plan
from smarts.core.sensor_manager import SensorManager
from smarts.core.utils.id import SocialAgentId
from smarts.core.vehicle_state import VehicleState
from smarts.sstudio.sstypes.actor.social_agent_actor import SocialAgentActor
from smarts.zoo.registry import make as make_social_agent

if TYPE_CHECKING:
from smarts.core.agent_interface import AgentInterface
from smarts.core.observations import Observation
from smarts.core.sensor_manager import SensorManager
from smarts.core.smarts import SMARTS


class AgentManager:
"""Tracks agent states and implements methods for managing agent life cycle.
Expand All @@ -46,27 +50,27 @@ class AgentManager:
time.
"""

def __init__(self, sim, interfaces):
def __init__(self, sim: SMARTS, interfaces: Dict[str, AgentInterface]):
from smarts.core.vehicle_index import VehicleIndex

self._log = logging.getLogger(self.__class__.__name__)
self._sim = weakref.ref(sim)
self._vehicle_index: VehicleIndex = sim.vehicle_index
self._sensor_manager: SensorManager = sim.sensor_manager
self._agent_buffer = None
self._ego_agent_ids = set()
self._ego_agent_ids: Set[str] = set()
self._social_agent_ids = set()

# Initial interfaces are for agents that are spawned at the beginning of the
# episode and that we'd re-spawn upon episode reset. This would include ego
# agents and social agents defined in SStudio. Hijacking agents in bubbles
# would not be included
self._initial_interfaces = interfaces
self._pending_agent_ids = set()
self._pending_social_agent_ids = set()
self._pending_agent_ids: Set[str] = set()
self._pending_social_agent_ids: Set[str] = set()

# Agent interfaces are interfaces for _all_ active agents
self._agent_interfaces = {}
self._agent_interfaces: Dict[str, AgentInterface] = {}

# TODO: This field is only for social agents, but is being used as if it were
# for any agent. Revisit the accessors.
Expand Down Expand Up @@ -111,7 +115,7 @@ def agent_interfaces(self) -> Dict[str, AgentInterface]:
"""A list of all agent to agent interface mappings."""
return self._agent_interfaces

def agent_interface_for_agent_id(self, agent_id) -> AgentInterface:
def agent_interface_for_agent_id(self, agent_id: str) -> AgentInterface:
"""Get the agent interface of a specific agent."""
return self._agent_interfaces[agent_id]

Expand All @@ -135,24 +139,24 @@ def shadowing_agent_ids(self) -> Set[str]:
"""Get all agents that currently observe, but not control, a vehicle."""
return self._vehicle_index.shadower_ids()

def is_ego(self, agent_id) -> bool:
def is_ego(self, agent_id: str) -> bool:
"""Test if the agent is an ego agent."""
return agent_id in self.ego_agent_ids

def remove_pending_agent_ids(self, agent_ids):
def remove_pending_agent_ids(self, agent_ids: Set[str]):
"""Remove an agent from the group of agents waiting to enter the simulation."""
assert agent_ids.issubset(self.agent_ids)
self._pending_agent_ids -= agent_ids

def agent_for_vehicle(self, vehicle_id) -> str:
def agent_for_vehicle(self, vehicle_id: str) -> str:
"""Get the controlling agent for the given vehicle."""
return self._vehicle_index.owner_id_from_vehicle_id(vehicle_id)

def agent_has_vehicle(self, agent_id) -> bool:
def agent_has_vehicle(self, agent_id: str) -> bool:
"""Test if an agent has an actor associated with it."""
return len(self.vehicles_for_agent(agent_id)) > 0

def vehicles_for_agent(self, agent_id) -> List[str]:
def vehicles_for_agent(self, agent_id: str) -> List[str]:
"""Get the vehicles associated with an agent."""
return self._vehicle_index.vehicle_ids_by_owner_id(
agent_id, include_shadowers=True
Expand Down Expand Up @@ -218,7 +222,7 @@ def observe(
Dict[str, Union[Dict[str, bool], bool]],
]:
"""Generate observations from all vehicles associated with an active agent."""
sim = self._sim()
sim: Optional[SMARTS] = self._sim()
assert sim
observations = {}
rewards = {}
Expand Down Expand Up @@ -340,7 +344,7 @@ def _vehicle_reward(self, vehicle_id: str) -> float:
def _vehicle_score(self, vehicle_id: str) -> float:
return self._vehicle_index.vehicle_by_id(vehicle_id).trip_meter_sensor()

def _filter_for_active_ego(self, dict_):
def _filter_for_active_ego(self, dict_: Dict[str, Observation]):
return {
id_: dict_[id_]
for id_ in self._ego_agent_ids
Expand Down Expand Up @@ -694,7 +698,9 @@ def _teardown_agents_by_ids(self, agent_ids, filter_ids: Set):
self._pending_agent_ids = self._pending_agent_ids - ids_
return ids_

def reset_agents(self, observations: Dict[str, Observation]):
def reset_agents(
self, observations: Dict[str, Observation]
) -> Dict[str, Observation]:
"""Reset agents, feeding in an initial observation."""
self._send_observations_to_social_agents(observations)

Expand Down
6 changes: 5 additions & 1 deletion smarts/core/glsl/map_values.frag
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ uniform sampler2D iChannel0;
uniform sampler2D iChannel1;

uniform vec3 empty_color;
uniform float elapsed_sim_time;


void mainImage( out vec4 fragColor, in vec2 fragCoord )
Expand All @@ -24,8 +25,11 @@ void mainImage( out vec4 fragColor, in vec2 fragCoord )

fragColor = texture(iChannel1, p);

vec3 color = vec3(0.0, sin(elapsed_sim_time) * 0.5 + 1.0, 0.0);
//empty_color;

if (fragColor.rgb == vec3(0.0, 0.0, 0.0)) {
fragColor = vec4(empty_color, 1.0);
fragColor = vec4(color, 1.0);
}
}

Expand Down
12 changes: 8 additions & 4 deletions smarts/core/observations.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,11 +26,10 @@
import numpy as np

if TYPE_CHECKING:
from smarts.core import plan, signals
from smarts.core.coordinates import Dimensions, Heading, Point, RefLinePoint
from smarts.core.events import Events
from smarts.core import plan
from smarts.core.road_map import Waypoint
from smarts.core import signals


class VehicleObservation(NamedTuple):
Expand Down Expand Up @@ -181,15 +180,20 @@ class ViaPoint(NamedTuple):
"""Road id this collectible is associated with."""
required_speed: float
"""Approximate speed required to collect this collectible."""
hit: bool
"""If this via point was hit in the last step."""


class Vias(NamedTuple):
"""Listing of nearby collectible ViaPoints and ViaPoints collected in the last step."""

near_via_points: List[ViaPoint]
"""Ordered list of nearby points that have not been hit."""
hit_via_points: List[ViaPoint]
"""List of points that were hit in the previous step."""

@property
def hit_via_points(self) -> List[ViaPoint]:
"""List of points that were hit in the previous step."""
return [vp for vp in self.near_via_points if vp.hit]


class SignalObservation(NamedTuple):
Expand Down
Loading

0 comments on commit 06d92a2

Please sign in to comment.