diff --git a/examples/video_recording.py b/examples/video_recording.py index f384b979..8fbd67e8 100644 --- a/examples/video_recording.py +++ b/examples/video_recording.py @@ -15,7 +15,7 @@ def main(): } env = gym.make( - "f1tenth_gym:f1tenth-v0, + "f1tenth_gym:f1tenth-v0", config={ "map": "Spielberg", "num_agents": 1, diff --git a/f1tenth_gym/envs/rendering/rendering_pyqt.py b/f1tenth_gym/envs/rendering/rendering_pyqt.py index d4010c28..1ea7a593 100644 --- a/f1tenth_gym/envs/rendering/rendering_pyqt.py +++ b/f1tenth_gym/envs/rendering/rendering_pyqt.py @@ -9,6 +9,7 @@ from PyQt6 import QtGui import pyqtgraph as pg from pyqtgraph.examples.utils import FrameCounter +from pyqtgraph.exporters import ImageExporter from PIL import ImageColor from .pyqt_objects import ( @@ -160,7 +161,10 @@ def __init__( self.follow_agent_flag: bool = False self.agent_to_follow: int = None - self.window.show() + if self.render_mode in ["human", "human_fast"]: + self.window.show() + elif self.render_mode == "rgb_array": + self.exporter = ImageExporter(self.canvas) def update(self, state: dict) -> None: """ @@ -304,9 +308,16 @@ def render(self) -> Optional[np.ndarray]: assert self.window is not None else: - # rgb_array - # TODO: extract the frame from the canvas - frame = None + # rgb_array mode => extract the frame from the canvas + qImage = self.exporter.export(toBytes=True) + + width = qImage.width() + height = qImage.height() + + ptr = qImage.bits() + ptr.setsize(height * width * 4) + frame = np.array(ptr).reshape(height, width, 4) # Copies the data + return frame def render_points(