From 25c9c3c63a06acfe4d2b7930be5bd0ba6f92e383 Mon Sep 17 00:00:00 2001
From: huangshiyu
Date: Wed, 20 Sep 2023 15:34:20 +0800
Subject: [PATCH] - support loading Stable-Baselines3 models from Hugging Face
 - fix value loss calculation bug: wrong huber_loss

---
 .gitignore                                |   1 +
 examples/cartpole/train_a2c.py            |   1 +
 examples/sb3/README.md                    |  28 +++
 examples/sb3/ppo.yaml                     |  25 +++
 examples/sb3/test_model.py                |  78 ++++++++
 examples/sb3/train_ppo.py                 |  57 ++++++
 openrl/algorithms/ppo.py                  |   5 +-
 openrl/configs/config.py                  |  20 ++
 openrl/drivers/onpolicy_driver.py         |   1 +
 openrl/drivers/rl_driver.py               |   3 +-
 .../networks/policy_value_network_sb3.py  | 173 ++++++++++++++++++
 openrl/modules/utils/util.py              |   2 +-
 openrl/utils/util.py                      |   5 +-
 13 files changed, 395 insertions(+), 4 deletions(-)
 create mode 100644 examples/sb3/README.md
 create mode 100644 examples/sb3/ppo.yaml
 create mode 100644 examples/sb3/test_model.py
 create mode 100644 examples/sb3/train_ppo.py
 create mode 100644 openrl/modules/networks/policy_value_network_sb3.py

diff --git a/.gitignore b/.gitignore
index db378067..c92a6657 100644
--- a/.gitignore
+++ b/.gitignore
@@ -160,3 +160,4 @@ opponent_pool
 wandb_run
 examples/dmc/new.gif
 /examples/snake/submissions/rl/actor_2000.pth
+/examples/sb3/ppo-CartPole-v1/
diff --git a/examples/cartpole/train_a2c.py b/examples/cartpole/train_a2c.py
index 3d200ec6..415f0bba 100644
--- a/examples/cartpole/train_a2c.py
+++ b/examples/cartpole/train_a2c.py
@@ -58,6 +58,7 @@ def evaluation():
         action, _ = agent.act(obs, deterministic=True)
         obs, r, done, info = env.step(action)
         total_step += 1
+        total_reward += np.mean(r)
         if total_step % 50 == 0:
             print(f"{total_step}: reward:{np.mean(r)}")
     env.close()
diff --git a/examples/sb3/README.md b/examples/sb3/README.md
new file mode 100644
index 00000000..2b77a547
--- /dev/null
+++ b/examples/sb3/README.md
@@ -0,0 +1,28 @@
+Load and use Stable-Baselines3 models from Hugging Face.
+
+## Installation
+
+```bash
+pip install huggingface-tool
+pip install rl_zoo3
+```
+
+## Download the sb3 model from Hugging Face
+
+```bash
+htool save-repo sb3/ppo-CartPole-v1 ppo-CartPole-v1
+```
+
+## Use OpenRL to load the sb3-trained model and evaluate it
+
+```bash
+python test_model.py
+```
+
+## Use OpenRL to load the sb3-trained model and continue training it
+
+```bash
+python train_ppo.py
+```
+
+
diff --git a/examples/sb3/ppo.yaml b/examples/sb3/ppo.yaml
new file mode 100644
index 00000000..c274e0c1
--- /dev/null
+++ b/examples/sb3/ppo.yaml
@@ -0,0 +1,25 @@
+use_share_model: true
+sb3_model_path: ppo-CartPole-v1/ppo-CartPole-v1.zip
+sb3_algo: ppo
+entropy_coef: 0.0
+gae_lambda: 0.8
+gamma: 0.98
+lr: 0.001
+episode_length: 32
+ppo_epoch: 20
+log_interval: 20
+log_each_episode: False
+
+callbacks:
+  - id: "EvalCallback"
+    args: {
+      "eval_env": { "id": "CartPole-v1", "env_num": 5 }, # how many envs to set up for evaluation
+      "n_eval_episodes": 20, # how many episodes to run for each evaluation
+      "eval_freq": 500, # how often to run evaluation
+      "log_path": "./results/eval_log_path", # where to save the evaluation results
+      "best_model_save_path": "./results/best_model/", # where to save the best model
+      "deterministic": True, # whether to use deterministic actions
+      "render": False, # whether to render the env
+      "asynchronous": True, # whether to run evaluation asynchronously
+      "stop_logic": "OR", # the logic to stop training: OR means training stops when any one of the conditions is met, AND means training stops when all conditions are met
+    }
\ No newline at end of file
diff --git a/examples/sb3/test_model.py b/examples/sb3/test_model.py
new file mode 100644
index 00000000..c0b4ddfb
--- /dev/null
+++ b/examples/sb3/test_model.py
@@ -0,0 +1,78 @@
+#!/usr/bin/env python
+# -*- coding: utf-8 -*-
+# Copyright 2023 The OpenRL Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+""""""
+
+# Use OpenRL to load a Stable-Baselines3 model for testing
+
+import numpy as np
+import torch
+
+from openrl.configs.config import create_config_parser
+from openrl.envs.common import make
+from openrl.modules.common.ppo_net import PPONet as Net
+from openrl.modules.networks.policy_value_network_sb3 import (
+    PolicyValueNetworkSB3 as PolicyValueNetwork,
+)
+from openrl.runners.common import PPOAgent as Agent
+
+
+def evaluation(local_trained_file_path=None):
+    # begin to test
+
+    cfg_parser = create_config_parser()
+    cfg = cfg_parser.parse_args(["--config", "ppo.yaml"])
+
+    # Create an environment for testing with 9 parallel environments. Set render_mode to "group_human" to visualize the rollout (rendering is disabled below).
+    render_mode = "group_human"
+    render_mode = None
+    env = make("CartPole-v1", render_mode=render_mode, env_num=9, asynchronous=True)
+    model_dict = {"model": PolicyValueNetwork}
+    net = Net(
+        env,
+        cfg=cfg,
+        model_dict=model_dict,
+        device="cuda" if torch.cuda.is_available() else "cpu",
+    )
+    # initialize the trainer
+    agent = Agent(
+        net,
+    )
+    if local_trained_file_path is not None:
+        agent.load(local_trained_file_path)
+    # The trained agent sets up the interactive environment it needs.
+    agent.set_env(env)
+    # Initialize the environment and get initial observations and environmental information.
+    obs, info = env.reset()
+    done = False
+
+    total_step = 0
+    total_reward = 0.0
+    while not np.any(done):
+        # Based on environmental observation input, predict next action.
+        action, _ = agent.act(obs, deterministic=True)
+        obs, r, done, info = env.step(action)
+        total_step += 1
+        total_reward += np.mean(r)
+        if total_step % 50 == 0:
+            print(f"{total_step}: reward:{np.mean(r)}")
+    env.close()
+    print("total step:", total_step)
+    print("total reward:", total_reward)
+
+
+if __name__ == "__main__":
+    evaluation()
diff --git a/examples/sb3/train_ppo.py b/examples/sb3/train_ppo.py
new file mode 100644
index 00000000..4471b365
--- /dev/null
+++ b/examples/sb3/train_ppo.py
@@ -0,0 +1,57 @@
+#!/usr/bin/env python
+# -*- coding: utf-8 -*-
+# Copyright 2023 The OpenRL Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+""""""
+import numpy as np
+import torch
+from test_model import evaluation
+
+from openrl.configs.config import create_config_parser
+from openrl.envs.common import make
+from openrl.modules.common.ppo_net import PPONet as Net
+from openrl.modules.networks.policy_value_network_sb3 import (
+    PolicyValueNetworkSB3 as PolicyValueNetwork,
+)
+from openrl.runners.common import PPOAgent as Agent
+
+
+def train_agent():
+    cfg_parser = create_config_parser()
+    cfg = cfg_parser.parse_args(["--config", "ppo.yaml"])
+
+    env = make("CartPole-v1", env_num=8, asynchronous=True)
+
+    model_dict = {"model": PolicyValueNetwork}
+    net = Net(
+        env,
+        cfg=cfg,
+        model_dict=model_dict,
+        device="cuda" if torch.cuda.is_available() else "cpu",
+    )
+
+    # initialize the trainer
+    agent = Agent(net)
+    # start training, set total number of training steps to 100000
+
+    agent.train(total_time_steps=100000)
+    env.close()
+
+    agent.save("./ppo_sb3_agent")
+
+
+if __name__ == "__main__":
+    train_agent()
+    evaluation(local_trained_file_path="./ppo_sb3_agent")
diff --git a/openrl/algorithms/ppo.py b/openrl/algorithms/ppo.py
index 51400374..1c226645 100644
--- a/openrl/algorithms/ppo.py
+++ b/openrl/algorithms/ppo.py
@@ -196,7 +196,8 @@ def cal_value_loss(
             ).sum() / active_masks_batch.sum()
         else:
             value_loss = value_loss.mean()
-
+        # print(value_loss)
+        # import pdb;pdb.set_trace()
         return value_loss
 
     def to_single_np(self, input):
@@ -209,8 +210,10 @@ def construct_loss_list(self, policy_loss, dist_entropy, value_loss, turn_on):
         final_p_loss = policy_loss - dist_entropy * self.entropy_coef
 
         loss_list.append(final_p_loss)
+
         final_v_loss = value_loss * self.value_loss_coef
         loss_list.append(final_v_loss)
+
         return loss_list
 
     def prepare_loss(
diff --git a/openrl/configs/config.py b/openrl/configs/config.py
index 8c714b68..2a616fe6 100644
--- a/openrl/configs/config.py
+++ b/openrl/configs/config.py
@@ -40,6 +40,20 @@ def create_config_parser():
 
     parser.add_argument("--callbacks", type=List[dict])
 
+    # For Stable-Baselines3
+    parser.add_argument(
+        "--sb3_model_path",
+        type=str,
+        default=None,
+        help="stable-baselines3 model path",
+    )
+    parser.add_argument(
+        "--sb3_algo",
+        type=str,
+        default=None,
+        help="stable-baselines3 algorithm",
+    )
+
     # For Hierarchical RL
     parser.add_argument(
         "--step_difference",
@@ -811,6 +825,12 @@ def create_config_parser():
         default=5,
         help="time duration between contiunous twice log printing.",
     )
+    parser.add_argument(
+        "--log_each_episode",
+        type=bool,
+        default=True,
+        help="Whether to log each episode number.",
+    )
     parser.add_argument(
         "--use_rich_handler",
         type=bool,
diff --git a/openrl/drivers/onpolicy_driver.py b/openrl/drivers/onpolicy_driver.py
index e6029dc3..747351c2 100644
--- a/openrl/drivers/onpolicy_driver.py
+++ b/openrl/drivers/onpolicy_driver.py
@@ -258,6 +258,7 @@ def act(
             values = np.zeros([self.n_rollout_threads, self.num_agents, 1])
         else:
             values = np.array(np.split(_t2n(value), self.n_rollout_threads))
+
         actions = np.array(np.split(_t2n(action), self.n_rollout_threads))
         action_log_probs = np.array(
             np.split(_t2n(action_log_prob), self.n_rollout_threads)
         )
diff --git a/openrl/drivers/rl_driver.py b/openrl/drivers/rl_driver.py
index 3b475855..b3c950d6 100644
--- a/openrl/drivers/rl_driver.py
+++ b/openrl/drivers/rl_driver.py
@@ -149,7 +149,8 @@ def run(self) -> None:
         self.reset_and_buffer_init()
         self.real_step = 0
         for episode in range(episodes):
-            self.logger.info("Episode: {}/{}".format(episode, episodes))
+            if self.cfg.log_each_episode:
+                self.logger.info("Episode: {}/{}".format(episode, episodes))
             self.episode = episode
             continue_training = self._inner_loop()
             if not continue_training:
diff --git a/openrl/modules/networks/policy_value_network_sb3.py b/openrl/modules/networks/policy_value_network_sb3.py
new file mode 100644
index 00000000..71886a0e
--- /dev/null
+++ b/openrl/modules/networks/policy_value_network_sb3.py
@@ -0,0 +1,173 @@
+#!/usr/bin/env python
+# -*- coding: utf-8 -*-
+# Copyright 2021 The OpenRL Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+""""""
+from typing import Any
+
+import numpy as np
+import torch
+from gymnasium import spaces
+from rl_zoo3 import ALGOS
+from torch import nn
+
+from openrl.modules.utils.valuenorm import ValueNorm
+from openrl.utils.util import check_v2 as check
+
+
+class PolicyValueNetworkSB3(nn.Module):
+    def __init__(
+        self,
+        cfg: Any,
+        input_space,
+        action_space,
+        device=torch.device("cpu"),
+        use_half=False,
+        disable_drop_out: bool = True,
+        extra_args=None,
+    ):
+        super(PolicyValueNetworkSB3, self).__init__()
+        assert cfg.sb3_algo is not None
+        assert cfg.sb3_model_path is not None
+        self._use_valuenorm = cfg.use_valuenorm
+        self.sb3_algo = cfg.sb3_algo
+        model = ALGOS[cfg.sb3_algo].load(cfg.sb3_model_path, custom_objects={})
+
+        self._policy_model = model.policy
+        self.use_half = use_half
+        self.tpdv = dict(dtype=torch.float32, device=device)
+        self.value_normalizer = (
+            ValueNorm(1, device=device) if self._use_valuenorm else None
+        )
+
+    def get_actor_para(self):
+        return self._policy_model.parameters()
+
+    def get_critic_para(self):
+        return self.get_actor_para()
+
+    def forward(self, forward_type, *args, **kwargs):
+        if forward_type == "original":
+            return self.get_actions(*args, **kwargs)
+        elif forward_type == "eval_actions":
+            return self.eval_actions(*args, **kwargs)
+        elif forward_type == "get_values":
+            return self.get_values(*args, **kwargs)
+        else:
+            raise NotImplementedError
+
+    def get_actions(
+        self, obs, rnn_states, masks, action_masks=None, deterministic=False
+    ):
+        if self.sb3_algo.endswith("_lstm"):
+            return self.get_rnn_action(
+                obs, rnn_states, masks, action_masks, deterministic
+            )
+        else:
+            return self.get_naive_action(
+                obs, rnn_states, masks, action_masks, deterministic
+            )
+
+    def get_rnn_action(
+        self, obs, rnn_states, masks, action_masks=None, deterministic=False
+    ):
+        # actions, rnn_states = self._policy_model.predict(obs,rnn_states,deterministic=deterministic)
+        #
+        # rnn_states = check(rnn_states, self.use_half, self.tpdv)
+        #
+        # return actions.unsqueeze(-1), action_log_probs.unsqueeze(-1), rnn_states
+        raise NotImplementedError
+
+    def get_naive_action(
+        self, obs, rnn_states, masks, action_masks=None, deterministic=False
+    ):
+        observation = obs
+        self._policy_model.set_training_mode(False)
+
+        observation, vectorized_env = self._policy_model.obs_to_tensor(observation)
+
+        with torch.no_grad():
+            action_distribution = self._policy_model.get_distribution(observation)
+            actions = action_distribution.get_actions(deterministic=deterministic)
+            action_log_probs = action_distribution.log_prob(actions)
+        # actions = self._policy_model._predict(observation, deterministic=deterministic)
+        # Convert to numpy, and reshape to the original action shape
+        actions = (
+            actions.cpu().numpy().reshape((-1, *self._policy_model.action_space.shape))
+        )
+
+        if isinstance(self._policy_model.action_space, spaces.Box):
+            if self._policy_model.squash_output:
+                # Rescale to proper domain when using squashing
+
+                actions = self._policy_model.unscale_action(actions)
+            else:
+                # Actions could be on arbitrary scale, so clip the actions to avoid
+                # out of bound error (e.g. if sampling from a Gaussian distribution)
+                actions = np.clip(
+                    actions,
+                    self._policy_model.action_space.low,
+                    self._policy_model.action_space.high,
+                )
+
+        # Remove batch dimension if needed
+        if not vectorized_env:
+            actions = actions.squeeze(axis=0)
+
+        actions = actions[:, np.newaxis]
+        action_log_probs = action_log_probs[:, np.newaxis]
+        return actions, action_log_probs, rnn_states
+
+    def eval_actions(
+        self, obs, rnn_states, action, masks, action_masks, active_masks=None
+    ):
+        obs = check(obs, self.use_half, self.tpdv)
+        action = check(action, self.use_half, self.tpdv).squeeze()
+        if self.sb3_algo.endswith("_lstm"):
+            return self.eval_actions_rnn(
+                obs, rnn_states, action, masks, action_masks, active_masks
+            )
+        else:
+            return self.eval_actions_naive(
+                obs, rnn_states, action, masks, action_masks, active_masks
+            )
+
+    def eval_actions_rnn(
+        self, obs, rnn_states, action, masks, action_masks, active_masks
+    ):
+        values, log_prob, entropy = self._policy_model.evaluate_actions(
+            obs, rnn_states, action
+        )
+        return log_prob, entropy.mean(), values
+
+    def eval_actions_naive(
+        self, obs, rnn_states, action, masks, action_masks, active_masks
+    ):
+        values, log_prob, entropy = self._policy_model.evaluate_actions(obs, action)
+        return log_prob, entropy.mean(), values
+
+    def get_values(self, obs, rnn_states, masks):
+        if self.sb3_algo.endswith("_lstm"):
+            return self.get_rnn_values(obs, rnn_states, masks)
+        else:
+            return self.get_naive_values(obs, rnn_states, masks)
+
+    def get_rnn_values(self, obs, rnn_states, masks):
+        raise NotImplementedError
+
+    def get_naive_values(self, obs, rnn_states, masks):
+        obs = check(obs, self.use_half, self.tpdv)
+        values = self._policy_model.predict_values(obs)
+        return values, rnn_states
diff --git a/openrl/modules/utils/util.py b/openrl/modules/utils/util.py
index 8b9c91c5..796e1928 100644
--- a/openrl/modules/utils/util.py
+++ b/openrl/modules/utils/util.py
@@ -19,7 +19,7 @@ def update_linear_schedule(optimizer, epoch, total_num_epochs, initial_lr):
 
 def huber_loss(e, d):
     a = (abs(e) <= d).float()
-    b = (e > d).float()
+    b = (abs(e) > d).float()
     return a * e**2 / 2 + b * d * (abs(e) - d / 2)
 
 
diff --git a/openrl/utils/util.py b/openrl/utils/util.py
index bf50472d..6a8378da 100644
--- a/openrl/utils/util.py
+++ b/openrl/utils/util.py
@@ -32,7 +32,10 @@ def check_v2(input, use_half=False, tpdv=None):
 
 
 def _t2n(x):
-    return x.detach().cpu().numpy()
+    if isinstance(x, torch.Tensor):
+        return x.detach().cpu().numpy()
+    else:
+        return x
 
 
 def get_system_info() -> Dict[str, str]:
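
For reference, a minimal sketch (not part of the patch) of why the `huber_loss` change above matters: with the old mask `b = (e > d)`, large negative errors matched neither branch and silently contributed zero loss, which is the "wrong huber_loss" bug named in the commit message. The sample error values below are illustrative only.

```python
import torch


def huber_loss_old(e, d):
    a = (abs(e) <= d).float()
    b = (e > d).float()  # bug: large negative errors match neither mask
    return a * e**2 / 2 + b * d * (abs(e) - d / 2)


def huber_loss_new(e, d):
    a = (abs(e) <= d).float()
    b = (abs(e) > d).float()  # fixed: symmetric in the sign of the error
    return a * e**2 / 2 + b * d * (abs(e) - d / 2)


e = torch.tensor([-3.0, -0.5, 0.5, 3.0])  # illustrative value errors
print(huber_loss_old(e, 1.0))  # tensor([0.0000, 0.1250, 0.1250, 2.5000]) -> -3.0 is ignored
print(huber_loss_new(e, 1.0))  # tensor([2.5000, 0.1250, 0.1250, 2.5000])
```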