Add converging of panda-gym
Iñigo Moreno committed Dec 15, 2023
1 parent a920960 commit 6ad38b0
Showing 3 changed files with 18 additions and 9 deletions.
3 changes: 3 additions & 0 deletions .gitignore
@@ -102,3 +102,6 @@ ENV/
 
 # mypy
 .mypy_cache/
+
+# TensorBoard log files
+runs/
6 changes: 6 additions & 0 deletions panda/README.md
@@ -0,0 +1,6 @@
+# Installation
+
+```bash
+git clone [email protected]:tecnalia-advancedmanufacturing-robotics/panda-gym.git
+pip install -e panda-gym
+```
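A quick way to sanity-check the install is to construct the environment this commit switches the notebook to. This is a minimal smoke test, assuming the fork exposes `ConvergingReachEnv` with the `control_type="joints"` argument used in the notebook diff below:

```python
# Smoke test for the panda-gym fork install (assumes the fork provides
# ConvergingReachEnv, as used in the notebook changed by this commit).
import panda_gym.envs

env = panda_gym.envs.ConvergingReachEnv(control_type="joints")
obs, info = env.reset()       # gymnasium-style reset: (observation, info)
print(env.observation_space)  # inspect the observation layout
print(env.action_space)       # one action per controlled joint in "joints" mode
env.close()
```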
18 changes: 9 additions & 9 deletions panda/stablebaselines.ipynb
@@ -13,6 +13,7 @@
 "import sb3_contrib\n",
 "import panda_gym.envs\n",
 "import gymnasium as gym\n",
+"from tqdm import tqdm\n",
 "\n",
 "import numpy as np\n",
 "%matplotlib inline"
@@ -24,7 +25,7 @@
 "metadata": {},
 "outputs": [],
 "source": [
-"env = panda_gym.envs.PandaReachEnv(control_type=\"joints\")\n",
+"env = panda_gym.envs.ConvergingReachEnv(control_type=\"joints\")\n",
 "print(env.observation_space)\n",
 "print(env.action_space)"
 ]
@@ -45,19 +46,19 @@
 " return obs, reward, terminated, truncated, info\n",
 "\n",
 "def getenv():\n",
-" env = panda_gym.envs.PandaPushEnv(control_type=\"joints\", reward_type=\"dense\", render_mode=\"rgb_array\")\n",
+" env = panda_gym.envs.ConvergingReachEnv(control_type=\"joints\", reward_type=\"dense\", render_mode=\"rgb_array\")\n",
 " env.task.distance_threshold = -1\n",
 " # flatten wrapper\n",
 " env = gym.wrappers.FlattenObservation(env)\n",
 " env = gym.wrappers.TimeLimit(env, max_episode_steps=100)\n",
 " # env = gym.wrappers.TransformReward(env, lambda r: -r**2)\n",
-" env = ActionPenalizerWrapper(env)\n",
+" # env = ActionPenalizerWrapper(env)\n",
 " # env = gym.wrappers.RecordVideo(env, \"./runs/RecurrentPPO\", lambda ep: ep % 100 == 0)\n",
 " return env\n",
 "\n",
 "n_envs = 1\n",
 "vec_env = stable_baselines3.common.env_util.make_vec_env(getenv, n_envs=n_envs)\n",
-"# vec_env = stable_baselines3.common.vec_env.VecVideoRecorder(vec_env, \"./runs/RecurrentPPO\", lambda ep: ep % 10 == 0, video_length=100)\n",
+"vec_env = stable_baselines3.common.vec_env.VecVideoRecorder(vec_env, \"./runs/RecurrentPPO\", lambda ep: ep % 10000 == 0, video_length=100)\n",
 "\n",
 "if True:\n",
 " model = sb3_contrib.RecurrentPPO(\n",
@@ -71,8 +72,8 @@
 " ent_coef=.01\n",
 " )\n",
 " stop_train_callback = StopTrainingOnNoModelImprovement(max_no_improvement_evals=2, min_evals=5, verbose=1)\n",
-" eval_callback = EvalCallback(vec_env, eval_freq=20000, callback_after_eval=stop_train_callback, verbose=1)\n",
-" checkpoint_callback = CheckpointCallback(save_freq=20000, save_path=\"./runs/RecurrentPPO\", name_prefix=\"RecurrentPPO\")\n",
+" eval_callback = EvalCallback(vec_env, eval_freq=10000, callback_after_eval=stop_train_callback, verbose=1)\n",
+" checkpoint_callback = CheckpointCallback(save_freq=10000, save_path=\"./runs/RecurrentPPO\", name_prefix=\"RecurrentPPO\")\n",
 " callback = CallbackList([eval_callback, checkpoint_callback])\n",
 " model.learn(total_timesteps=1000000, tb_log_name=\"versus random\", progress_bar=False, log_interval=1, callback=callback)\n",
 " model.save(f\"./runs/RecurrentPPO2\")\n",
@@ -98,14 +99,13 @@
 "# Episode start signals are used to reset the lstm states\n",
 "episode_starts = np.ones((num_envs,), dtype=bool)\n",
 "score = 0\n",
-"while True:\n",
+"for _ in tqdm(range(200)):\n",
 " action, lstm_states = model.predict(obs, state=lstm_states, episode_start=episode_starts, deterministic=True)\n",
 " obs, rewards, dones, info = vec_env.step(action)\n",
 " score+=rewards.mean()\n",
 " episode_starts = dones\n",
 " if dones.all():\n",
-" break\n",
-"print(score)"
+" break"
 ]
 }
 ],
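Pieced together, the fragments in this last hunk form the recurrent evaluation loop below. This is a sketch rather than the verbatim notebook cell: it assumes `model` is the trained `sb3_contrib.RecurrentPPO` and `vec_env` the vectorized environment built by `getenv` above, and it restores the `print(score)` the commit drops so the accumulated reward stays visible. Threading `lstm_states` and `episode_starts` through `predict` is how sb3_contrib carries and resets the policy's recurrent state:

```python
# Recurrent evaluation loop (assumes `model` and `vec_env` from the cells above).
import numpy as np
from tqdm import tqdm

obs = vec_env.reset()
num_envs = vec_env.num_envs
lstm_states = None                                 # no recurrent state before the first step
episode_starts = np.ones((num_envs,), dtype=bool)  # True resets the LSTM state for that env
score = 0.0

for _ in tqdm(range(200)):
    action, lstm_states = model.predict(
        obs, state=lstm_states, episode_start=episode_starts, deterministic=True
    )
    obs, rewards, dones, info = vec_env.step(action)
    score += rewards.mean()
    episode_starts = dones  # finished envs must reset their LSTM state next step
    if dones.all():
        break
print(score)
```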