Skip to content

Commit

Permalink
Pre release dataset (#243)
Browse files Browse the repository at this point in the history
* fix replay bug

* format
  • Loading branch information
QuanyiLi authored Nov 24, 2022
1 parent 55600b3 commit f277f8b
Showing 1 changed file with 13 additions and 10 deletions.
23 changes: 13 additions & 10 deletions metadrive/policy/replay_policy.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import numpy as np
from metadrive.utils.math_utils import wrap_to_pi

from metadrive.policy.base_policy import BasePolicy

Expand Down Expand Up @@ -41,13 +42,15 @@
class ReplayEgoCarPolicy(BasePolicy):
def __init__(self, control_object, random_seed):
super(ReplayEgoCarPolicy, self).__init__(control_object=control_object)
reference_traj = control_object.navigation.reference_trajectory
self.traj_info = [seg["start_point"] for seg in reference_traj.segment_property]
self.heading_info = [seg["heading"] for seg in reference_traj.segment_property]
self.traj_info.append(reference_traj.segment_property[-1]["end_point"])
self.trajectory_data = self.engine.traffic_manager.current_traffic_data
self.traj_info = [
self.engine.traffic_manager.parse_vehicle_state(
self.trajectory_data[self.engine.traffic_manager.sdc_index]["state"], i
) for i in range(len(self.trajectory_data[self.engine.traffic_manager.sdc_index]["state"]))
]
self.start_index = 0
self.init_pos = self.traj_info[0]
self.heading = self.heading_info[0]
self.init_pos = self.traj_info[0]["position"]
self.heading = self.traj_info[0]["heading"]
self.timestep = 0
self.damp = 0
# how many times the replay data is slowed down
Expand All @@ -64,12 +67,12 @@ def act(self, *args, **kwargs):
if self.timestep == self.start_index:
self.control_object.set_position(self.init_pos)
elif self.timestep < len(self.traj_info):
self.control_object.set_position(self.traj_info[int(self.timestep)])
self.control_object.set_position(self.traj_info[int(self.timestep)]["position"])

if self.heading is None or self.timestep >= len(self.heading_info):
if self.heading is None or self.timestep >= len(self.traj_info):
pass
else:
this_heading = self.heading_info[int(self.timestep)]
self.control_object.set_heading_theta(this_heading)
this_heading = self.traj_info[int(self.timestep)]["heading"]
self.control_object.set_heading_theta(this_heading, rad_to_degree=False)

return [0, 0]

0 comments on commit f277f8b

Please sign in to comment.