forked from joshuachen6/KrishLib
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathtrain.py
30 lines (23 loc) · 1.02 KB
/
train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
from gymnasium.envs.registration import register
import gymnasium as gym
from stable_baselines3 import PPO
from stable_baselines3.common.env_util import make_vec_env
from stable_baselines3.common.vec_env import SubprocVecEnv
register(
id="MotionProfilePacman-v1",
entry_point="env:MotionProfilePacman",
max_episode_steps=300,
)
if __name__=="__main__":
env = make_vec_env("MotionProfilePacman-v1", n_envs=8, env_kwargs={"render_mode":None},vec_env_cls=SubprocVecEnv)
model = PPO("MultiInputPolicy", env, device="cpu", verbose=1, tensorboard_log="tensorboard") #default policy is "MlpPolicy"
model.learn(total_timesteps=int(1e4), log_interval=4)
model.save("ppo_pacbot")
del model # remove to demonstrate saving and loading
model = PPO.load("ppo_pacbot")
obs, info = env.reset()
while True:
action, _states = model.predict(obs, deterministic=True)
obs, reward, terminated, truncated, info = env.step(action)
if terminated or truncated:
obs, info = env.reset()