-
Notifications
You must be signed in to change notification settings - Fork 0
/
recording_tools.py
58 lines (49 loc) · 1.82 KB
/
recording_tools.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
# Set up fake display; otherwise rendering will fail.
# NOTE: these statements run at import time, as a module-level side effect.
import os
# Launch Xvfb on display :1 with screen 0 at 1024x768x24. The trailing "&" makes the
# shell background the process so this call returns immediately; the return value
# is ignored, so a failure to start Xvfb is not detected here.
os.system("Xvfb :1 -screen 0 1024x768x24 &")
# Point this process at the display started above.
os.environ['DISPLAY'] = ':1'
import base64
from pathlib import Path
from IPython import display as ipythondisplay
def show_videos(video_path="", prefix=""):
    """
    Embed every matching ``.mp4`` file as an HTML ``<video>`` tag and display it
    through IPython.

    Taken from https://github.com/eleurent/highway-env

    :param video_path: (str) Path to the folder containing videos
    :param prefix: (str) Filter the videos, showing only the ones starting with this prefix
    """
    # Imported locally: only needed to escape the file path used as alt text.
    from html import escape

    video_tags = []
    # glob() yields files in arbitrary order; sort so the display order is stable.
    for mp4 in sorted(Path(video_path).glob("{}*.mp4".format(prefix))):
        # Inline the file as a base64 data URI so the HTML needs no external file.
        video_b64 = base64.b64encode(mp4.read_bytes())
        video_tags.append(
            """<video alt="{}" autoplay
                loop controls style="height: 400px;">
                <source src="data:video/mp4;base64,{}" type="video/mp4" />
                </video>""".format(
                # Escape quotes/angle brackets so an unusual file name cannot
                # break out of the alt attribute.
                escape(str(mp4), quote=True),
                video_b64.decode("ascii"),
            )
        )
    # With no matching file this displays an empty HTML snippet.
    ipythondisplay.display(ipythondisplay.HTML(data="<br>".join(video_tags)))
from stable_baselines3.common.vec_env import VecVideoRecorder, DummyVecEnv
def record_video(env_id, model, video_length=500, prefix="", video_folder="videos/"):
    """
    Run ``model`` in a newly created environment wrapped in a video recorder.

    :param env_id: (str) Id passed to ``gym.make`` to create the environment
    :param model: (RL model) Object whose ``predict(obs)`` returns a pair whose
        first element is the action to step with
    :param video_length: (int) Number of steps to run; also passed to the recorder
        as ``video_length``
    :param prefix: (str) Passed to the recorder as ``name_prefix``
    :param video_folder: (str) Passed to the recorder as ``video_folder``
    """
    # NOTE(review): ``gym`` is not imported anywhere in this file, so it must already
    # be present in the module's globals when this runs — confirm how the file is used.
    eval_env = DummyVecEnv([lambda: gym.make(env_id, render_mode="rgb_array")])
    try:
        # Start the video at step=0 and record ``video_length`` steps
        eval_env = VecVideoRecorder(
            eval_env,
            video_folder=video_folder,
            record_video_trigger=lambda step: step == 0,
            video_length=video_length,
            name_prefix=prefix,
        )
        obs = eval_env.reset()
        for _ in range(video_length):
            action, _ = model.predict(obs)
            obs, _, _, _ = eval_env.step(action)
    finally:
        # Close the video recorder even when the rollout raises. If wrapping itself
        # failed, ``eval_env`` is still the bare vectorized env and that is closed.
        eval_env.close()