diff --git a/.gitignore b/.gitignore
index b4afc3006..8b9439f05 100644
--- a/.gitignore
+++ b/.gitignore
@@ -102,3 +102,6 @@ ENV/
 
 # mypy
 .mypy_cache/
+
+# TensorBoard log files
+runs/
\ No newline at end of file
diff --git a/panda/README.md b/panda/README.md
new file mode 100644
index 000000000..05d089a33
--- /dev/null
+++ b/panda/README.md
@@ -0,0 +1,6 @@
+# Installation
+
+```bash
+git clone git@github.com:tecnalia-advancedmanufacturing-robotics/panda-gym.git
+pip install -e panda-gym
+```
diff --git a/panda/stablebaselines.ipynb b/panda/stablebaselines.ipynb
index 73cfcded8..fac67779c 100644
--- a/panda/stablebaselines.ipynb
+++ b/panda/stablebaselines.ipynb
@@ -13,6 +13,7 @@
     "import sb3_contrib\n",
     "import panda_gym.envs\n",
     "import gymnasium as gym\n",
+    "from tqdm import tqdm\n",
     "\n",
     "import numpy as np\n",
     "%matplotlib inline"
@@ -24,7 +25,7 @@
    "metadata": {},
    "outputs": [],
    "source": [
-    "env = panda_gym.envs.PandaReachEnv(control_type=\"joints\")\n",
+    "env = panda_gym.envs.ConvergingReachEnv(control_type=\"joints\")\n",
     "print(env.observation_space)\n",
     "print(env.action_space)"
    ]
@@ -45,19 +46,19 @@
     "        return obs, reward, terminated, truncated, info\n",
     "\n",
     "def getenv():\n",
-    "    env = panda_gym.envs.PandaPushEnv(control_type=\"joints\", reward_type=\"dense\", render_mode=\"rgb_array\")\n",
+    "    env = panda_gym.envs.ConvergingReachEnv(control_type=\"joints\", reward_type=\"dense\", render_mode=\"rgb_array\")\n",
     "    env.task.distance_threshold = -1\n",
     "    # flatten wrapper\n",
     "    env = gym.wrappers.FlattenObservation(env)\n",
     "    env = gym.wrappers.TimeLimit(env, max_episode_steps=100)\n",
     "    # env = gym.wrappers.TransformReward(env, lambda r: -r**2)\n",
-    "    env = ActionPenalizerWrapper(env)\n",
+    "    # env = ActionPenalizerWrapper(env)\n",
     "    # env = gym.wrappers.RecordVideo(env, \"./runs/RecurrentPPO\", lambda ep: ep % 100 == 0)\n",
     "    return env\n",
     "\n",
     "n_envs = 1\n",
     "vec_env = stable_baselines3.common.env_util.make_vec_env(getenv, n_envs=n_envs)\n",
-    "# vec_env = stable_baselines3.common.vec_env.VecVideoRecorder(vec_env, \"./runs/RecurrentPPO\", lambda ep: ep % 10 == 0, video_length=100)\n",
+    "vec_env = stable_baselines3.common.vec_env.VecVideoRecorder(vec_env, \"./runs/RecurrentPPO\", lambda ep: ep % 10000 == 0, video_length=100)\n",
     "\n",
     "if True:\n",
     "    model = sb3_contrib.RecurrentPPO(\n",
@@ -71,8 +72,8 @@
     "        ent_coef=.01\n",
     "    )\n",
     "    stop_train_callback = StopTrainingOnNoModelImprovement(max_no_improvement_evals=2, min_evals=5, verbose=1)\n",
-    "    eval_callback = EvalCallback(vec_env, eval_freq=20000, callback_after_eval=stop_train_callback, verbose=1)\n",
-    "    checkpoint_callback = CheckpointCallback(save_freq=20000, save_path=\"./runs/RecurrentPPO\", name_prefix=\"RecurrentPPO\")\n",
+    "    eval_callback = EvalCallback(vec_env, eval_freq=10000, callback_after_eval=stop_train_callback, verbose=1)\n",
+    "    checkpoint_callback = CheckpointCallback(save_freq=10000, save_path=\"./runs/RecurrentPPO\", name_prefix=\"RecurrentPPO\")\n",
     "    callback = CallbackList([eval_callback, checkpoint_callback])\n",
     "    model.learn(total_timesteps=1000000, tb_log_name=\"versus random\", progress_bar=False, log_interval=1, callback=callback)\n",
     "    model.save(f\"./runs/RecurrentPPO2\")\n",
@@ -98,14 +99,13 @@
     "# Episode start signals are used to reset the lstm states\n",
     "episode_starts = np.ones((num_envs,), dtype=bool)\n",
     "score = 0\n",
-    "while True:\n",
+    "for _ in tqdm(range(200)):\n",
     "    action, lstm_states = model.predict(obs, state=lstm_states, episode_start=episode_starts, deterministic=True)\n",
     "    obs, rewards, dones, info = vec_env.step(action)\n",
     "    score+=rewards.mean()\n",
     "    episode_starts = dones\n",
     "    if dones.all():\n",
-    "        break\n",
-    "print(score)"
+    "        break"
    ]
  }
 ],