diff --git a/.gitignore b/.gitignore index 0dd122d3d..408324403 100644 --- a/.gitignore +++ b/.gitignore @@ -90,7 +90,6 @@ celerybeat-schedule # virtualenv .venv venv/ -ENV/ # Spyder project settings .spyderproject @@ -103,4 +102,4 @@ ENV/ /site # mypy -.mypy_cache/ +.mypy_cache/ \ No newline at end of file diff --git a/benchmark/torch/maddpg/.benchmark/maddpg_torch.png b/benchmark/torch/maddpg/.benchmark/maddpg_torch.png deleted file mode 100644 index ddd6e86e2..000000000 Binary files a/benchmark/torch/maddpg/.benchmark/maddpg_torch.png and /dev/null differ diff --git a/benchmark/torch/maddpg/README.md b/benchmark/torch/maddpg/README.md index b76439712..0a5e41a14 100644 --- a/benchmark/torch/maddpg/README.md +++ b/benchmark/torch/maddpg/README.md @@ -10,7 +10,7 @@ A simple multi-agent particle world based on gym. Please see [here](https://gith Mean episode reward (every 1000 episodes) in training process (totally 25000 episodes).

-result +result

### Experiments result @@ -19,37 +19,37 @@ Mean episode reward (every 1000 episodes) in training process (totally 25000 epi simple
-MADDPG_simple +MADDPG_simple simple_adversary
-MADDPG_simple_adversary +MADDPG_simple_adversary simple_push
-MADDPG_simple_push +MADDPG_simple_push -simple_reference
-MADDPG_simple_reference +simple_crypto
+MADDPG_simple_crypto simple_speaker_listener
-MADDPG_simple_speaker_listener +MADDPG_simple_speaker_listener simple_spread
-MADDPG_simple_spread +MADDPG_simple_spread simple_tag
-MADDPG_simple_tag +MADDPG_simple_tag simple_world_comm
-MADDPG_simple_world_comm +MADDPG_simple_world_comm @@ -58,9 +58,9 @@ simple_world_comm
### Dependencies: + python3.5+ + torch -+ [parl>=2.0.2](https://github.com/PaddlePaddle/PARL) -+ [multiagent-particle-envs](https://github.com/openai/multiagent-particle-envs) -+ gym==0.10.5 ++ [parl>=2.0.4](https://github.com/PaddlePaddle/PARL) ++ PettingZoo==1.17.0 ++ gym==0.23.1 ### Start Training: ``` @@ -68,7 +68,11 @@ simple_world_comm
python train.py # To train for other scenario, model is automatically saved every 1000 episodes -# python train.py --env [ENV_NAME] +python train.py --env [ENV_NAME] # To show animation effects after training -# python train.py --env [ENV_NAME] --show --restore +python train.py --env [ENV_NAME] --show --restore + +# To train and evaluate scenarios with continuous action spaces +python train.py --env [ENV_NAME] --continuous_actions +python train.py --env [ENV_NAME] --continuous_actions --show --restore diff --git a/benchmark/torch/maddpg/simple_agent.py b/benchmark/torch/maddpg/simple_agent.py index 6cfa91008..f4e1a3688 100644 --- a/benchmark/torch/maddpg/simple_agent.py +++ b/benchmark/torch/maddpg/simple_agent.py @@ -1,4 +1,4 @@ -# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/benchmark/torch/maddpg/simple_model.py b/benchmark/torch/maddpg/simple_model.py index 5ac9f8151..3ae90b7cc 100644 --- a/benchmark/torch/maddpg/simple_model.py +++ b/benchmark/torch/maddpg/simple_model.py @@ -1,4 +1,4 @@ -# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -26,9 +26,13 @@ def weights_init_(m): class MAModel(parl.Model): - def __init__(self, obs_dim, act_dim, critic_in_dim): + def __init__(self, + obs_dim, + act_dim, + critic_in_dim, + continuous_actions=False): super(MAModel, self).__init__() - self.actor_model = ActorModel(obs_dim, act_dim) + self.actor_model = ActorModel(obs_dim, act_dim, continuous_actions) self.critic_model = CriticModel(critic_in_dim) def policy(self, obs): @@ -45,19 +49,26 @@ def get_critic_params(self): class ActorModel(parl.Model): - def __init__(self, obs_dim, act_dim): + def __init__(self, obs_dim, act_dim, continuous_actions=False): super(ActorModel, self).__init__() + self.continuous_actions = continuous_actions hid1_size = 64 hid2_size = 64 self.fc1 = nn.Linear(obs_dim, hid1_size) self.fc2 = nn.Linear(hid1_size, hid2_size) self.fc3 = nn.Linear(hid2_size, act_dim) + if self.continuous_actions: + std_hid_size = 64 + self.std_fc = nn.Linear(std_hid_size, act_dim) self.apply(weights_init_) def forward(self, obs): hid1 = F.relu(self.fc1(obs)) hid2 = F.relu(self.fc2(hid1)) means = self.fc3(hid2) + if self.continuous_actions: + act_std = self.std_fc(hid2) + return (means, act_std) return means diff --git a/benchmark/torch/maddpg/train.py b/benchmark/torch/maddpg/train.py index a3f175ecc..86b2da5db 100644 --- a/benchmark/torch/maddpg/train.py +++ b/benchmark/torch/maddpg/train.py @@ -1,4 +1,4 @@ -# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
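For context on how the new two-headed actor output is consumed, the sketch below (illustrative only, not part of the patch; the sizes and batch are made up) builds the torch `ActorModel` above in continuous mode, wraps its `(means, act_std)` output in `DiagGaussianDistribution`, and squashes the sample with `tanh`, mirroring what the algorithm-side changes later in this diff do inside `MADDPG.predict`:

```python
# Illustrative usage of the continuous ActorModel defined above.
import torch
from parl.core.torch.policy_distribution import DiagGaussianDistribution
from simple_model import ActorModel

obs_dim, act_dim = 8, 2                # hypothetical sizes
actor = ActorModel(obs_dim, act_dim, continuous_actions=True)

obs = torch.randn(4, obs_dim)          # a batch of 4 observations
means, act_std = actor(obs)            # tuple output in continuous mode
dist = DiagGaussianDistribution((means, act_std))  # second head is read as log-std
action = torch.tanh(dist.sample())     # squashed into [-1, 1]
print(action.shape)                    # torch.Size([4, 2])
```

Note that `DiagGaussianDistribution.sample` in this patch draws its noise on the CUDA device when one is visible, so on a GPU machine the model and observations should be moved to CUDA first.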
@@ -19,15 +19,15 @@ from simple_model import MAModel from simple_agent import MAAgent from parl.algorithms import MADDPG -from parl.env.multiagent_simple_env import MAenv +from parl.env.multiagent_env import MAenv from parl.utils import logger, summary +from gym import spaces CRITIC_LR = 0.01 # learning rate for the critic model ACTOR_LR = 0.01 # learning rate of the actor model GAMMA = 0.95 # reward discount factor TAU = 0.01 # soft update BATCH_SIZE = 1024 -MAX_EPISODES = 25000 # stop condition:number of episodes MAX_STEP_PER_EPISODE = 25 # maximum step per episode STAT_RATE = 1000 # statistical interval of save model or count reward @@ -79,36 +79,34 @@ def run_episode(env, agents): def train_agent(): - env = MAenv(args.env) + env = MAenv(args.env, args.continuous_actions) + if args.continuous_actions: + assert isinstance(env.action_space[0], spaces.Box) + + # print env info logger.info('agent num: {}'.format(env.n)) - logger.info('observation_space: {}'.format(env.observation_space)) - logger.info('action_space: {}'.format(env.action_space)) logger.info('obs_shape_n: {}'.format(env.obs_shape_n)) logger.info('act_shape_n: {}'.format(env.act_shape_n)) + logger.info('observation_space: {}'.format(env.observation_space)) + logger.info('action_space: {}'.format(env.action_space)) for i in range(env.n): logger.info('agent {} obs_low:{} obs_high:{}'.format( i, env.observation_space[i].low, env.observation_space[i].high)) logger.info('agent {} act_n:{}'.format(i, env.act_shape_n[i])) - if ('low' in dir(env.action_space[i])): + if (isinstance(env.action_space[i], spaces.Box)): logger.info('agent {} act_low:{} act_high:{} act_shape:{}'.format( i, env.action_space[i].low, env.action_space[i].high, env.action_space[i].shape)) - logger.info('num_discrete_space:{}'.format( - env.action_space[i].num_discrete_space)) - - from gym import spaces - from multiagent.multi_discrete import MultiDiscrete - for space in env.action_space: - assert (isinstance(space, spaces.Discrete) - or isinstance(space, MultiDiscrete)) critic_in_dim = sum(env.obs_shape_n) + sum(env.act_shape_n) logger.info('critic_in_dim: {}'.format(critic_in_dim)) + # build agents agents = [] for i in range(env.n): - model = MAModel(env.obs_shape_n[i], env.act_shape_n[i], critic_in_dim) + model = MAModel(env.obs_shape_n[i], env.act_shape_n[i], critic_in_dim, + args.continuous_actions) algorithm = MADDPG( model, agent_index=i, @@ -142,7 +140,7 @@ def train_agent(): t_start = time.time() logger.info('Starting...') - while total_episodes <= MAX_EPISODES: + while total_episodes <= args.max_episodes: # run an episode ep_reward, ep_agent_rewards, steps = run_episode(env, agents) summary.add_scalar('train_reward/episode', ep_reward, total_episodes) @@ -208,8 +206,20 @@ def train_agent(): type=str, default='./model', help='directory for saving model') + parser.add_argument( + '--continuous_actions', + action='store_true', + default=False, + help='use continuous action mode or not') + parser.add_argument( + '--max_episodes', + type=int, + default=25000, + help='the maximum number of episodes') + parser.add_argument('--seed', type=int, default=0) args = parser.parse_args() + print('========== args: ', args) logger.set_dir('./train_log/' + str(args.env)) train_agent() diff --git a/examples/MADDPG/.benchmark/maddpg_paddle.png b/examples/MADDPG/.benchmark/maddpg_paddle.png deleted file mode 100644 index 244e9a625..000000000 Binary files a/examples/MADDPG/.benchmark/maddpg_paddle.png and /dev/null differ diff --git a/examples/MADDPG/README.md 
b/examples/MADDPG/README.md index 3f9ba2eb7..076bcbf15 100644 --- a/examples/MADDPG/README.md +++ b/examples/MADDPG/README.md @@ -10,7 +10,7 @@ A simple multi-agent particle world based on gym. Please see [here](https://gith Mean episode reward (every 1000 episodes) in training process (totally 25000 episodes).

-result +result

### Experiments result @@ -19,37 +19,37 @@ Mean episode reward (every 1000 episodes) in training process (totally 25000 epi simple
-MADDPG_simple +MADDPG_simple simple_adversary
-MADDPG_simple_adversary +MADDPG_simple_adversary simple_push
-MADDPG_simple_push +MADDPG_simple_push -simple_reference
-MADDPG_simple_reference +simple_crypto
+MADDPG_simple_crypto simple_speaker_listener
-MADDPG_simple_speaker_listener +MADDPG_simple_speaker_listener simple_spread
-MADDPG_simple_spread +MADDPG_simple_spread simple_tag
-MADDPG_simple_tag +MADDPG_simple_tag simple_world_comm
-MADDPG_simple_world_comm +MADDPG_simple_world_comm @@ -58,9 +58,10 @@ simple_world_comm
### Dependencies: + python3.5+ + [paddlepaddle>=2.0.0](https://github.com/PaddlePaddle/Paddle) -+ [parl>=2.0.2](https://github.com/PaddlePaddle/PARL) -+ [multiagent-particle-envs](https://github.com/openai/multiagent-particle-envs) -+ gym==0.10.5 ++ [parl>=2.0.4](https://github.com/PaddlePaddle/PARL) ++ PettingZoo==1.17.0 ++ gym==0.23.1 + ### Start Training: ``` @@ -68,7 +69,12 @@ simple_world_comm
python train.py # To train for other scenario, model is automatically saved every 1000 episodes -# python train.py --env [ENV_NAME] +python train.py --env [ENV_NAME] # To show animation effects after training -# python train.py --env [ENV_NAME] --show --restore +python train.py --env [ENV_NAME] --show --restore + +# To train and evaluate scenarios with continuous action spaces +python train.py --env [ENV_NAME] --continuous_actions +python train.py --env [ENV_NAME] --continuous_actions --show --restore +``` diff --git a/examples/MADDPG/simple_agent.py b/examples/MADDPG/simple_agent.py index fbb837b82..7db79cf29 100644 --- a/examples/MADDPG/simple_agent.py +++ b/examples/MADDPG/simple_agent.py @@ -1,4 +1,4 @@ -# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -16,7 +16,6 @@ import paddle import numpy as np from parl.utils import ReplayMemory -from parl.utils import machine_info, get_gpu_count class MAAgent(parl.Agent): diff --git a/examples/MADDPG/simple_model.py b/examples/MADDPG/simple_model.py index 6fcb7cd7b..413dd94b0 100644 --- a/examples/MADDPG/simple_model.py +++ b/examples/MADDPG/simple_model.py @@ -1,4 +1,4 @@ -# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -19,9 +19,13 @@ class MAModel(parl.Model): - def __init__(self, obs_dim, act_dim, critic_in_dim): + def __init__(self, + obs_dim, + act_dim, + critic_in_dim, + continuous_actions=False): super(MAModel, self).__init__() - self.actor_model = ActorModel(obs_dim, act_dim) + self.actor_model = ActorModel(obs_dim, act_dim, continuous_actions) self.critic_model = CriticModel(critic_in_dim) def policy(self, obs): @@ -38,8 +42,9 @@ def get_critic_params(self): class ActorModel(parl.Model): - def __init__(self, obs_dim, act_dim): + def __init__(self, obs_dim, act_dim, continuous_actions=False): super(ActorModel, self).__init__() + self.continuous_actions = continuous_actions hid1_size = 64 hid2_size = 64 self.fc1 = nn.Linear( @@ -57,11 +62,21 @@ def __init__(self, obs_dim, act_dim): act_dim, weight_attr=paddle.ParamAttr( initializer=paddle.nn.initializer.XavierUniform())) + if self.continuous_actions: + std_hid_size = 64 + self.std_fc = nn.Linear( + std_hid_size, + act_dim, + weight_attr=paddle.ParamAttr( + initializer=paddle.nn.initializer.XavierUniform())) def forward(self, obs): hid1 = F.relu(self.fc1(obs)) hid2 = F.relu(self.fc2(hid1)) means = self.fc3(hid2) + if self.continuous_actions: + act_std = self.std_fc(hid2) + return (means, act_std) return means diff --git a/examples/MADDPG/train.py b/examples/MADDPG/train.py index a3f175ecc..b38896a6a 100644 --- a/examples/MADDPG/train.py +++ b/examples/MADDPG/train.py @@ -1,4 +1,4 @@ -# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
@@ -19,15 +19,15 @@ from simple_model import MAModel from simple_agent import MAAgent from parl.algorithms import MADDPG -from parl.env.multiagent_simple_env import MAenv +from parl.env.multiagent_env import MAenv from parl.utils import logger, summary +from gym import spaces CRITIC_LR = 0.01 # learning rate for the critic model ACTOR_LR = 0.01 # learning rate of the actor model GAMMA = 0.95 # reward discount factor TAU = 0.01 # soft update BATCH_SIZE = 1024 -MAX_EPISODES = 25000 # stop condition:number of episodes MAX_STEP_PER_EPISODE = 25 # maximum step per episode STAT_RATE = 1000 # statistical interval of save model or count reward @@ -79,36 +79,33 @@ def run_episode(env, agents): def train_agent(): - env = MAenv(args.env) + env = MAenv(args.env, args.continuous_actions) + if args.continuous_actions: + assert isinstance(env.action_space[0], spaces.Box) + + # print env info logger.info('agent num: {}'.format(env.n)) - logger.info('observation_space: {}'.format(env.observation_space)) - logger.info('action_space: {}'.format(env.action_space)) logger.info('obs_shape_n: {}'.format(env.obs_shape_n)) logger.info('act_shape_n: {}'.format(env.act_shape_n)) - + logger.info('observation_space: {}'.format(env.observation_space)) + logger.info('action_space: {}'.format(env.action_space)) for i in range(env.n): logger.info('agent {} obs_low:{} obs_high:{}'.format( i, env.observation_space[i].low, env.observation_space[i].high)) logger.info('agent {} act_n:{}'.format(i, env.act_shape_n[i])) - if ('low' in dir(env.action_space[i])): + if (isinstance(env.action_space[i], spaces.Box)): logger.info('agent {} act_low:{} act_high:{} act_shape:{}'.format( i, env.action_space[i].low, env.action_space[i].high, env.action_space[i].shape)) - logger.info('num_discrete_space:{}'.format( - env.action_space[i].num_discrete_space)) - - from gym import spaces - from multiagent.multi_discrete import MultiDiscrete - for space in env.action_space: - assert (isinstance(space, spaces.Discrete) - or isinstance(space, MultiDiscrete)) critic_in_dim = sum(env.obs_shape_n) + sum(env.act_shape_n) logger.info('critic_in_dim: {}'.format(critic_in_dim)) + # build agents agents = [] for i in range(env.n): - model = MAModel(env.obs_shape_n[i], env.act_shape_n[i], critic_in_dim) + model = MAModel(env.obs_shape_n[i], env.act_shape_n[i], critic_in_dim, + args.continuous_actions) algorithm = MADDPG( model, agent_index=i, @@ -142,7 +139,7 @@ def train_agent(): t_start = time.time() logger.info('Starting...') - while total_episodes <= MAX_EPISODES: + while total_episodes <= args.max_episodes: # run an episode ep_reward, ep_agent_rewards, steps = run_episode(env, agents) summary.add_scalar('train_reward/episode', ep_reward, total_episodes) @@ -208,8 +205,19 @@ def train_agent(): type=str, default='./model', help='directory for saving model') + parser.add_argument( + '--continuous_actions', + action='store_true', + default=False, + help='use continuous action mode or not') + parser.add_argument( + '--max_episodes', + type=int, + default=25000, + help='stop condition: number of episodes') args = parser.parse_args() + print('========== args: ', args) logger.set_dir('./train_log/' + str(args.env)) train_agent() diff --git a/parl/algorithms/paddle/maddpg.py b/parl/algorithms/paddle/maddpg.py index e25afe348..7906d8574 100644 --- a/parl/algorithms/paddle/maddpg.py +++ b/parl/algorithms/paddle/maddpg.py @@ -1,4 +1,4 @@ -# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# Copyright (c) 2022 PaddlePaddle Authors. 
All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -21,6 +21,7 @@ __all__ = ['MADDPG'] +from parl.core.paddle.policy_distribution import DiagGaussianDistribution from parl.core.paddle.policy_distribution import SoftCategoricalDistribution from parl.core.paddle.policy_distribution import SoftMultiCategoricalDistribution @@ -42,6 +43,9 @@ def SoftPDistribution(logits, act_space): elif (hasattr(act_space, 'num_discrete_space')): return SoftMultiCategoricalDistribution(logits, act_space.low, act_space.high) + # is instance of gym.spaces.Box + elif (hasattr(act_space, 'high')): + return DiagGaussianDistribution(logits) else: raise AssertionError("act_space must be instance of \ gym.spaces.Discrete or multiagent.multi_discrete.MultiDiscrete") @@ -80,6 +84,11 @@ def __init__(self, assert isinstance(actor_lr, float) assert isinstance(critic_lr, float) + self.continuous_actions = False + if not len(act_space) == 0 and hasattr(act_space[0], 'high') \ + and not hasattr(act_space[0], 'num_discrete_space'): + self.continuous_actions = True + self.agent_index = agent_index self.act_space = act_space self.gamma = gamma @@ -117,6 +126,8 @@ def predict(self, obs, use_target_model=False): action = SoftPDistribution( logits=policy, act_space=self.act_space[self.agent_index]).sample() + if self.continuous_actions: + action = paddle.tanh(action) return action def Q(self, obs_n, act_n, use_target_model=False): @@ -150,6 +161,8 @@ def _actor_learn(self, obs_n, act_n): sample_this_action = SoftPDistribution( logits=this_policy, act_space=self.act_space[self.agent_index]).sample() + if self.continuous_actions: + sample_this_action = paddle.tanh(sample_this_action) # action_input_n = deepcopy(act_n) action_input_n = act_n + [] @@ -157,6 +170,9 @@ def _actor_learn(self, obs_n, act_n): eval_q = self.Q(obs_n, action_input_n) act_cost = paddle.mean(-1.0 * eval_q) + # when continuous, 'this_policy' will be a tuple with two element: (mean, std) + if self.continuous_actions: + this_policy = paddle.concat(this_policy, axis=-1) act_reg = paddle.mean(paddle.square(this_policy)) cost = act_cost + act_reg * 1e-3 diff --git a/parl/algorithms/torch/maddpg.py b/parl/algorithms/torch/maddpg.py index b9b9187de..70d2005b8 100644 --- a/parl/algorithms/torch/maddpg.py +++ b/parl/algorithms/torch/maddpg.py @@ -1,4 +1,4 @@ -# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -21,6 +21,7 @@ __all__ = ['MADDPG'] +from parl.core.torch.policy_distribution import DiagGaussianDistribution from parl.core.torch.policy_distribution import SoftCategoricalDistribution from parl.core.torch.policy_distribution import SoftMultiCategoricalDistribution @@ -29,7 +30,7 @@ def SoftPDistribution(logits, act_space): """ Select SoftCategoricalDistribution or SoftMultiCategoricalDistribution according to act_space. 
Args: - logits (paddle tensor): the output of policy model + logits (torch tensor): the output of policy model act_space: action space, must be gym.spaces.Discrete or multiagent.multi_discrete.MultiDiscrete Returns: @@ -42,6 +43,10 @@ def SoftPDistribution(logits, act_space): elif (hasattr(act_space, 'num_discrete_space')): return SoftMultiCategoricalDistribution(logits, act_space.low, act_space.high) + # is instance of gym.spaces.Box + elif (hasattr(act_space, 'high')): + return DiagGaussianDistribution(logits) + else: raise AssertionError("act_space must be instance of \ gym.spaces.Discrete or multiagent.multi_discrete.MultiDiscrete") @@ -80,6 +85,11 @@ def __init__(self, assert isinstance(actor_lr, float) assert isinstance(critic_lr, float) + self.continuous_actions = False + if not len(act_space) == 0 and hasattr(act_space[0], 'high') \ + and not hasattr(act_space[0], 'num_discrete_space'): + self.continuous_actions = True + self.agent_index = agent_index self.act_space = act_space self.gamma = gamma @@ -100,31 +110,37 @@ def __init__(self, def predict(self, obs, use_target_model=False): """ use the policy model to predict actions + Args: - obs (paddle tensor): observation, shape([B] + shape of obs_n[agent_index]) + obs (torch tensor): observation, shape([B] + shape of obs_n[agent_index]) use_target_model (bool): use target_model or not - + Returns: - act (paddle tensor): action, shape([B] + shape of act_n[agent_index]) + act (torch tensor): action, shape([B] + shape of act_n[agent_index]) """ if use_target_model: policy = self.target_model.policy(obs) else: policy = self.model.policy(obs) + action = SoftPDistribution( logits=policy, act_space=self.act_space[self.agent_index]).sample() + if self.continuous_actions: + action = torch.tanh(action) + return action def Q(self, obs_n, act_n, use_target_model=False): """ use the value model to predict Q values - Args: - obs_n (list of paddle tensor): all agents' observation, len(agent's num) + shape([B] + shape of obs_n) - act_n (list of paddle tensor): all agents' action, len(agent's num) + shape([B] + shape of act_n) + + Args: + obs_n (list of torch tensor): all agents' observation, len(agent's num) + shape([B] + shape of obs_n) + act_n (list of torch tensor): all agents' action, len(agent's num) + shape([B] + shape of act_n) use_target_model (bool): use target_model or not Returns: - Q (paddle tensor): Q value of this agent, shape([B]) + Q (torch tensor): Q value of this agent, shape([B]) """ if use_target_model: return self.target_model.value(obs_n, act_n) @@ -146,6 +162,8 @@ def _actor_learn(self, obs_n, act_n): sample_this_action = SoftPDistribution( logits=this_policy, act_space=self.act_space[self.agent_index]).sample() + if self.continuous_actions: + sample_this_action = torch.tanh(sample_this_action) # action_input_n = deepcopy(act_n) action_input_n = act_n + [] @@ -153,6 +171,9 @@ def _actor_learn(self, obs_n, act_n): eval_q = self.Q(obs_n, action_input_n) act_cost = torch.mean(-1.0 * eval_q) + # when continuous, 'this_policy' will be a tuple with two element: (mean, std) + if self.continuous_actions: + this_policy = torch.cat(this_policy, dim=-1) act_reg = torch.mean(torch.square(this_policy)) cost = act_cost + act_reg * 1e-3 @@ -174,6 +195,12 @@ def _critic_learn(self, obs_n, act_n, target_q): return cost def sync_target(self, decay=None): + """ update the target network with the training network + + Args: + decay(float): the decaying factor while updating the target network with the training network. 
+                0 represents a hard **assignment** of the training weights. None performs a
+                soft update whose rate is controlled by the hyperparameter `tau`.
+        """
         if decay is None:
             decay = 1.0 - self.tau
         self.model.sync_weights_to(self.target_model, decay=decay)
diff --git a/parl/core/paddle/policy_distribution.py b/parl/core/paddle/policy_distribution.py
index 95ccbf4ca..608772014 100644
--- a/parl/core/paddle/policy_distribution.py
+++ b/parl/core/paddle/policy_distribution.py
@@ -1,4 +1,4 @@
-# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -14,6 +14,7 @@
 
 import paddle
 import paddle.nn.functional as F
+import numpy as np
 
 __all__ = [
     'PolicyDistribution', 'CategoricalDistribution',
@@ -39,6 +40,78 @@ def logp(self, actions):
         raise NotImplementedError
 
 
+class DiagGaussianDistribution(PolicyDistribution):
+    """DiagGaussian distribution for continuous action spaces."""
+
+    def __init__(self, logits):
+        """
+        Args:
+            logits: A tuple of (mean, logstd)
+                mean: A float32 tensor with shape [BATCH_SIZE, NUM_ACTIONS] of unnormalized policy logits
+                logstd: A float32 tensor with shape [BATCH_SIZE, NUM_ACTIONS]
+        """
+        assert len(logits) == 2
+        assert len(logits[0].shape) == 2 and len(logits[1].shape) == 2
+        self.logits = logits
+        (mean, logstd) = logits
+        self.mean = mean
+        self.logstd = logstd
+
+        self.std = paddle.exp(self.logstd)
+
+    def sample(self):
+        """
+        Returns:
+            sample_action: A float32 tensor with shape [BATCH_SIZE, NUM_ACTIONS], an action sampled
+                from the Gaussian, i.e. the mean plus exploration noise scaled by the std.
+        """
+        mean_shape = paddle.to_tensor(self.mean.shape, dtype='int64')
+        random_normal = paddle.normal(shape=mean_shape)
+        return self.mean + self.std * random_normal
+
+    def entropy(self):
+        """
+        Returns:
+            entropy: A float32 tensor with shape [BATCH_SIZE] of entropy of self policy distribution.
+        """
+        entropy = paddle.sum(
+            self.logstd + 0.5 * np.log(2.0 * np.pi * np.e), axis=1)
+        return entropy
+
+    def logp(self, actions):
+        """
+        Args:
+            actions: A float32 tensor with shape [BATCH_SIZE, NUM_ACTIONS]
+
+        Returns:
+            actions_log_prob: A float32 tensor with shape [BATCH_SIZE]
+        """
+        assert len(actions.shape) == 2
+
+        norm_actions = paddle.sum(
+            paddle.square((actions - self.mean) / self.std), axis=1)
+        pi_item = 0.5 * np.log(2.0 * np.pi) * actions.shape[1]
+        actions_log_prob = -0.5 * norm_actions - pi_item - paddle.sum(
+            self.logstd, axis=1)
+
+        return actions_log_prob
+
+    def kl(self, other):
+        """
+        Args:
+            other: object of DiagGaussianDistribution
+
+        Returns:
+            kl: A float32 tensor with shape [BATCH_SIZE]
+        """
+        assert isinstance(other, DiagGaussianDistribution)
+
+        temp = (paddle.square(self.std) + paddle.square(self.mean - other.mean)
+                ) / (2.0 * paddle.square(other.std))
+        kl = paddle.sum(other.logstd - self.logstd + temp - 0.5, axis=1)
+        return kl
+
+
 class CategoricalDistribution(PolicyDistribution):
     """Categorical distribution for discrete action spaces."""
 
diff --git a/parl/core/torch/policy_distribution.py b/parl/core/torch/policy_distribution.py
index 2efe15b27..77a6a41d0 100644
--- a/parl/core/torch/policy_distribution.py
+++ b/parl/core/torch/policy_distribution.py
@@ -1,4 +1,4 @@
-# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -14,6 +14,7 @@
 
 import torch
 import torch.nn.functional as F
+import numpy as np
 
 __all__ = [
     'PolicyDistribution', 'CategoricalDistribution',
@@ -39,6 +40,77 @@ def logp(self, actions):
         raise NotImplementedError
 
 
+class DiagGaussianDistribution(PolicyDistribution):
+    """DiagGaussian distribution for continuous action spaces."""
+
+    def __init__(self, logits):
+        """
+        Args:
+            logits: A tuple of (mean, logstd)
+                mean: A float32 tensor with shape [BATCH_SIZE, NUM_ACTIONS] of unnormalized policy logits
+                logstd: A float32 tensor with shape [BATCH_SIZE, NUM_ACTIONS]
+        """
+        assert len(logits) == 2
+        assert len(logits[0].shape) == 2 and len(logits[1].shape) == 2
+        self.logits = logits
+        (mean, logstd) = logits
+        self.mean = mean
+        self.logstd = logstd
+
+        self.std = torch.exp(self.logstd)
+
+    def sample(self):
+        """
+        Returns:
+            sample_action: A float32 tensor with shape [BATCH_SIZE, NUM_ACTIONS], an action sampled
+                from the Gaussian, i.e. the mean plus exploration noise scaled by the std.
+        """
+        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+        random_normal = torch.randn(size=self.mean.shape).to(device)
+        return self.mean + self.std * random_normal
+
+    def entropy(self):
+        """
+        Returns:
+            entropy: A float32 tensor with shape [BATCH_SIZE] of entropy of self policy distribution.
+        """
+        entropy = torch.sum(
+            self.logstd + 0.5 * np.log(2.0 * np.pi * np.e), axis=1)
+        return entropy
+
+    def logp(self, actions):
+        """
+        Args:
+            actions: A float32 tensor with shape [BATCH_SIZE, NUM_ACTIONS]
+        Returns:
+            actions_log_prob: A float32 tensor with shape [BATCH_SIZE]
+        """
+        assert len(actions.shape) == 2
+
+        norm_actions = torch.sum(
+            torch.square((actions - self.mean) / self.std), axis=1)
+        actions_shape = torch.tensor(actions.shape, dtype=torch.float32)
+        pi_item = 0.5 * np.log(2.0 * np.pi) * actions_shape[1]
+        actions_log_prob = -0.5 * norm_actions - pi_item - torch.sum(
+            self.logstd, axis=1)
+
+        return actions_log_prob
+
+    def kl(self, other):
+        """
+        Args:
+            other: object of DiagGaussianDistribution
+        Returns:
+            kl: A float32 tensor with shape [BATCH_SIZE]
+        """
+        assert isinstance(other, DiagGaussianDistribution)
+
+        temp = (torch.square(self.std) + torch.square(self.mean - other.mean)
+                ) / (2.0 * torch.square(other.std))
+        kl = torch.sum(other.logstd - self.logstd + temp - 0.5, axis=1)
+        return kl
+
+
 class CategoricalDistribution(PolicyDistribution):
     """Categorical distribution for discrete action spaces."""
 
diff --git a/parl/env/multiagent_env.py b/parl/env/multiagent_env.py
new file mode 100644
index 000000000..b44c206c7
--- /dev/null
+++ b/parl/env/multiagent_env.py
@@ -0,0 +1,141 @@
+# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
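For reference, the two `DiagGaussianDistribution` classes above implement the standard closed-form diagonal-Gaussian quantities. The plain-numpy sketch below (illustrative names, not part of the patch) can be used to cross-check `logp`, `entropy`, and `kl` in either backend:

```python
# Reference formulas for a diagonal Gaussian with per-dimension mean and log-std.
import numpy as np


def diag_gaussian_logp(x, mean, logstd):
    # log N(x; mean, diag(std^2)), summed over the action dimension
    std = np.exp(logstd)
    norm = np.sum(np.square((x - mean) / std), axis=1)
    return -0.5 * norm - 0.5 * np.log(2.0 * np.pi) * x.shape[1] - np.sum(
        logstd, axis=1)


def diag_gaussian_entropy(logstd):
    # sum of per-dimension entropies 0.5 * log(2 * pi * e * std^2)
    return np.sum(logstd + 0.5 * np.log(2.0 * np.pi * np.e), axis=1)


def diag_gaussian_kl(mean_p, logstd_p, mean_q, logstd_q):
    # KL(p || q) between two diagonal Gaussians, summed over dimensions
    var_p, var_q = np.exp(2.0 * logstd_p), np.exp(2.0 * logstd_q)
    t = (var_p + np.square(mean_p - mean_q)) / (2.0 * var_q)
    return np.sum(logstd_q - logstd_p + t - 0.5, axis=1)
```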
+ +import numpy as np +try: + import gym + from gym import spaces + from pettingzoo.mpe import simple_v2 + from pettingzoo.mpe import simple_adversary_v2 + from pettingzoo.mpe import simple_crypto_v2 + from pettingzoo.mpe import simple_push_v2 + from pettingzoo.mpe import simple_speaker_listener_v3 + from pettingzoo.mpe import simple_spread_v2 + from pettingzoo.mpe import simple_tag_v2 + from pettingzoo.mpe import simple_world_comm_v2 +except: + raise ImportError('Can not use MAenv from parl.env.multiagent_env. \n \ + try `pip install PettingZoo==1.17.0` and `pip install gym==0.23.1` \n \ + (PettingZoo 1.17.0 requires gym>=0.21.0)') + + +def MAenv(scenario_name, continuous_actions=False): + env_list = [ + 'simple', 'simple_adversary', 'simple_crypto', 'simple_push', + 'simple_speaker_listener', 'simple_spread', 'simple_tag', + 'simple_world_comm' + ] + assert scenario_name in env_list, 'Env {} not found (valid envs include {})'.format( + scenario_name, env_list) + if scenario_name == 'simple': + env = simple_v2.parallel_env( + max_cycles=25, continuous_actions=continuous_actions) + elif scenario_name == 'simple_adversary': + env = simple_adversary_v2.parallel_env( + N=2, max_cycles=25, continuous_actions=continuous_actions) + elif scenario_name == 'simple_crypto': + env = simple_crypto_v2.parallel_env( + max_cycles=25, continuous_actions=continuous_actions) + elif scenario_name == 'simple_push': + env = simple_push_v2.parallel_env( + max_cycles=25, continuous_actions=continuous_actions) + elif scenario_name == 'simple_speaker_listener': + env = simple_speaker_listener_v3.parallel_env( + max_cycles=25, continuous_actions=continuous_actions) + elif scenario_name == 'simple_spread': + env = simple_spread_v2.parallel_env( + N=3, + local_ratio=0, + max_cycles=25, + continuous_actions=continuous_actions) + elif scenario_name == 'simple_tag': + env = simple_tag_v2.parallel_env( + num_good=1, + num_adversaries=3, + num_obstacles=2, + max_cycles=25, + continuous_actions=continuous_actions) + elif scenario_name == 'simple_world_comm': + env = simple_world_comm_v2.parallel_env( + num_good=2, + num_adversaries=4, + num_obstacles=1, + num_food=2, + max_cycles=25, + num_forests=2, + continuous_actions=continuous_actions) + else: + pass + + env = mpe_wrapper_for_pettingzoo(env, continuous_actions) + return env + + +class mpe_wrapper_for_pettingzoo(gym.Wrapper): + def __init__(self, env=None, continuous_actions=False): + gym.Wrapper.__init__(self, env) + self.continuous_actions = continuous_actions + self.observation_space = list(self.observation_spaces.values()) + self.action_space = list(self.action_spaces.values()) + assert len(self.observation_space) == len(self.action_space) + self.n = len(self.observation_space) + self.agents_name = list(self.observation_spaces.keys()) + self.obs_shape_n = [ + self.get_shape(self.observation_space[i]) for i in range(self.n) + ] + self.act_shape_n = [ + self.get_shape(self.action_space[i]) for i in range(self.n) + ] + + def get_shape(self, input_space): + """ + Args: + input_space: environment space + + Returns: + space shape + """ + if (isinstance(input_space, spaces.Box)): + if (len(input_space.shape) == 1): + return input_space.shape[0] + else: + return input_space.shape + elif (isinstance(input_space, spaces.Discrete)): + return input_space.n + else: + print('[Error] shape is {}, not Box or Discrete'.format( + input_space.shape)) + raise NotImplementedError + + def reset(self): + obs = self.env.reset() + return list(obs.values()) + + def step(self, actions): + 
actions_dict = dict() + for i, act in enumerate(actions): + agent = self.agents_name[i] + if self.continuous_actions: + assert np.all(((act<=1.0 + 1e-3), (act>=-1.0 - 1e-3))), \ + 'the action should be in range [-1.0, 1.0], but got {}'.format(act) + high = self.action_space[i].high + low = self.action_space[i].low + mapped_action = low + (act - (-1.0)) * ((high - low) / 2.0) + mapped_action = np.clip(mapped_action, low, high) + actions_dict[agent] = mapped_action + else: + actions_dict[agent] = np.argmax(act) + obs, reward, done, info = self.env.step(actions_dict) + return list(obs.values()), list(reward.values()), list( + done.values()), list(info.values()) diff --git a/parl/env/multiagent_simple_env.py b/parl/env/multiagent_simple_env.py index 913ebda3e..af9115866 100644 --- a/parl/env/multiagent_simple_env.py +++ b/parl/env/multiagent_simple_env.py @@ -12,56 +12,69 @@ # See the License for the specific language governing permissions and # limitations under the License. -from gym import spaces -from multiagent.multi_discrete import MultiDiscrete -from multiagent.environment import MultiAgentEnv -import multiagent.scenarios as scenarios +try: + from gym import spaces + from multiagent.multi_discrete import MultiDiscrete + from multiagent.environment import MultiAgentEnv + import multiagent.scenarios as scenarios + from parl.utils import logger + class MAenv(MultiAgentEnv): + """ multiagent environment warppers for maddpg + """ -class MAenv(MultiAgentEnv): - """ multiagent environment warppers for maddpg - """ - - def __init__(self, scenario_name): - env_list = [ - 'simple', 'simple_adversary', 'simple_crypto', 'simple_push', - 'simple_reference', 'simple_speaker_listener', 'simple_spread', - 'simple_tag', 'simple_world_comm' - ] - assert scenario_name in env_list, 'Env {} not found (valid envs include {})'.format( - scenario_name, env_list) - # load scenario from script - scenario = scenarios.load(scenario_name + ".py").Scenario() - # create world - world = scenario.make_world() - # initial multiagent environment - super().__init__(world, scenario.reset_world, scenario.reward, - scenario.observation) - self.obs_shape_n = [ - self.get_shape(self.observation_space[i]) for i in range(self.n) - ] - self.act_shape_n = [ - self.get_shape(self.action_space[i]) for i in range(self.n) - ] + def __init__(self, scenario_name): + env_list = [ + 'simple', 'simple_adversary', 'simple_crypto', 'simple_push', + 'simple_reference', 'simple_speaker_listener', 'simple_spread', + 'simple_tag', 'simple_world_comm' + ] + assert scenario_name in env_list, 'Env {} not found (valid envs include {})'.format( + scenario_name, env_list) + # load scenario from script + scenario = scenarios.load(scenario_name + ".py").Scenario() + # create world + world = scenario.make_world() + # initial multiagent environment + super().__init__(world, scenario.reset_world, scenario.reward, + scenario.observation) + self.obs_shape_n = [ + self.get_shape(self.observation_space[i]) + for i in range(self.n) + ] + self.act_shape_n = [ + self.get_shape(self.action_space[i]) for i in range(self.n) + ] - def get_shape(self, input_space): - """ - Args: - input_space: environment space + def get_shape(self, input_space): + """ + Args: + input_space: environment space - Returns: - space shape - """ - if (isinstance(input_space, spaces.Box)): - if (len(input_space.shape) == 1): - return input_space.shape[0] + Returns: + space shape + """ + if (isinstance(input_space, spaces.Box)): + if (len(input_space.shape) == 1): + return 
input_space.shape[0] + else: + return input_space.shape + elif (isinstance(input_space, spaces.Discrete)): + return input_space.n + elif (isinstance(input_space, MultiDiscrete)): + return sum(input_space.high - input_space.low + 1) else: - return input_space.shape - elif (isinstance(input_space, spaces.Discrete)): - return input_space.n - elif (isinstance(input_space, MultiDiscrete)): - return sum(input_space.high - input_space.low + 1) - else: - print('[Error] shape is {}, not Box or Discrete or MultiDiscrete'. - format(input_space.shape)) - raise NotImplementedError + print( + '[Error] shape is {}, not Box or Discrete or MultiDiscrete' + .format(input_space.shape)) + raise NotImplementedError + + logger.warning( + 'The `MAenv` from `parl.env.multiagent_simple_env` is deprecated since parl 2.0.4 and will be removed in parl 3.0. \n \ + We recommend `from parl.env.multiagent_env import MAenv` instead, which supports continuous control.' + ) +except: + raise ImportError( + 'Cannot use MAenv from parl.env.multiagent_simple_env, \n \ + please pip install multiagent according to https://github.com/openai/multiagent-particle-envs \ + as well as `pip install gym==0.10.5`')
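To round off the migration, here is a minimal rollout sketch against the new `MAenv` wrapper from `parl/env/multiagent_env.py` above (a sketch only: it assumes PettingZoo==1.17.0 and gym==0.23.1 are installed as listed in the READMEs; the scenario name and random actions are illustrative):

```python
# Random-action rollout through the PettingZoo-backed MAenv wrapper.
import numpy as np
from parl.env.multiagent_env import MAenv

env = MAenv('simple_spread', continuous_actions=True)
print(env.n, env.obs_shape_n, env.act_shape_n)   # agent count and per-agent shapes

obs_n = env.reset()                              # list with one observation per agent
for _ in range(25):                              # every scenario above uses max_cycles=25
    # the wrapper expects per-agent actions in [-1, 1] and rescales them to each
    # agent's Box bounds inside step()
    act_n = [np.random.uniform(-1.0, 1.0, shape) for shape in env.act_shape_n]
    obs_n, reward_n, done_n, info_n = env.step(act_n)
    if all(done_n):
        break
```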