diff --git a/.gitignore b/.gitignore
index 0dd122d3d..408324403 100644
--- a/.gitignore
+++ b/.gitignore
@@ -90,7 +90,6 @@ celerybeat-schedule
# virtualenv
.venv
venv/
-ENV/
# Spyder project settings
.spyderproject
@@ -103,4 +102,4 @@ ENV/
/site
# mypy
-.mypy_cache/
+.mypy_cache/
\ No newline at end of file
diff --git a/benchmark/torch/maddpg/.benchmark/maddpg_torch.png b/benchmark/torch/maddpg/.benchmark/maddpg_torch.png
deleted file mode 100644
index ddd6e86e2..000000000
Binary files a/benchmark/torch/maddpg/.benchmark/maddpg_torch.png and /dev/null differ
diff --git a/benchmark/torch/maddpg/README.md b/benchmark/torch/maddpg/README.md
index b76439712..0a5e41a14 100644
--- a/benchmark/torch/maddpg/README.md
+++ b/benchmark/torch/maddpg/README.md
@@ -10,7 +10,7 @@ A simple multi-agent particle world based on gym. Please see [here](https://gith
Mean episode reward (every 1000 episodes) in training process (totally 25000 episodes).
-
+
### Experiments result
@@ -19,37 +19,37 @@ Mean episode reward (every 1000 episodes) in training process (totally 25000 epi
simple
-
+
|
simple_adversary
-
+
|
simple_push
-
+
|
-simple_reference
-
+simple_crypto
+
|
simple_speaker_listener
-
+
|
simple_spread
-
+
|
simple_tag
-
+
|
simple_world_comm
-
+
|
@@ -58,9 +58,9 @@ simple_world_comm
### Dependencies:
+ python3.5+
+ torch
-+ [parl>=2.0.2](https://github.com/PaddlePaddle/PARL)
-+ [multiagent-particle-envs](https://github.com/openai/multiagent-particle-envs)
-+ gym==0.10.5
++ [parl>=2.0.4](https://github.com/PaddlePaddle/PARL)
++ PettingZoo==1.17.0
++ gym==0.23.1
### Start Training:
```
@@ -68,7 +68,11 @@ simple_world_comm
python train.py
# To train for other scenario, model is automatically saved every 1000 episodes
-# python train.py --env [ENV_NAME]
+python train.py --env [ENV_NAME]
# To show animation effects after training
-# python train.py --env [ENV_NAME] --show --restore
+python train.py --env [ENV_NAME] --show --restore
+
+# To train and evaluate scenarios with continuous action spaces
+python train.py --env [ENV_NAME] --continuous_actions
+python train.py --env [ENV_NAME] --continuous_actions --show --restore
diff --git a/benchmark/torch/maddpg/simple_agent.py b/benchmark/torch/maddpg/simple_agent.py
index 6cfa91008..f4e1a3688 100644
--- a/benchmark/torch/maddpg/simple_agent.py
+++ b/benchmark/torch/maddpg/simple_agent.py
@@ -1,4 +1,4 @@
-# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
diff --git a/benchmark/torch/maddpg/simple_model.py b/benchmark/torch/maddpg/simple_model.py
index 5ac9f8151..3ae90b7cc 100644
--- a/benchmark/torch/maddpg/simple_model.py
+++ b/benchmark/torch/maddpg/simple_model.py
@@ -1,4 +1,4 @@
-# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -26,9 +26,13 @@ def weights_init_(m):
class MAModel(parl.Model):
- def __init__(self, obs_dim, act_dim, critic_in_dim):
+ def __init__(self,
+ obs_dim,
+ act_dim,
+ critic_in_dim,
+ continuous_actions=False):
super(MAModel, self).__init__()
- self.actor_model = ActorModel(obs_dim, act_dim)
+ self.actor_model = ActorModel(obs_dim, act_dim, continuous_actions)
self.critic_model = CriticModel(critic_in_dim)
def policy(self, obs):
@@ -45,19 +49,26 @@ def get_critic_params(self):
class ActorModel(parl.Model):
- def __init__(self, obs_dim, act_dim):
+ def __init__(self, obs_dim, act_dim, continuous_actions=False):
super(ActorModel, self).__init__()
+ self.continuous_actions = continuous_actions
hid1_size = 64
hid2_size = 64
self.fc1 = nn.Linear(obs_dim, hid1_size)
self.fc2 = nn.Linear(hid1_size, hid2_size)
self.fc3 = nn.Linear(hid2_size, act_dim)
+ if self.continuous_actions:
+ std_hid_size = 64
+ self.std_fc = nn.Linear(std_hid_size, act_dim)
self.apply(weights_init_)
def forward(self, obs):
hid1 = F.relu(self.fc1(obs))
hid2 = F.relu(self.fc2(hid1))
means = self.fc3(hid2)
+ if self.continuous_actions:
+ act_std = self.std_fc(hid2)
+ return (means, act_std)
return means
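
A minimal sketch (batch size and action dimension are assumed here) of how the `(means, act_std)` tuple returned above is consumed: the MADDPG algorithm changed later in this PR wraps it in the new `DiagGaussianDistribution` and squashes the sample with `tanh`.

```python
import torch

from parl.core.torch.policy_distribution import DiagGaussianDistribution

means = torch.zeros(4, 2)    # mean head output, shape [batch, act_dim] (assumed 4 x 2)
act_std = torch.zeros(4, 2)  # std head output, interpreted as logstd by the distribution
dist = DiagGaussianDistribution((means, act_std))
action = torch.tanh(dist.sample())  # MADDPG.predict squashes samples into [-1, 1]
print(action.shape)  # torch.Size([4, 2])
```
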
diff --git a/benchmark/torch/maddpg/train.py b/benchmark/torch/maddpg/train.py
index a3f175ecc..86b2da5db 100644
--- a/benchmark/torch/maddpg/train.py
+++ b/benchmark/torch/maddpg/train.py
@@ -1,4 +1,4 @@
-# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -19,15 +19,15 @@
from simple_model import MAModel
from simple_agent import MAAgent
from parl.algorithms import MADDPG
-from parl.env.multiagent_simple_env import MAenv
+from parl.env.multiagent_env import MAenv
from parl.utils import logger, summary
+from gym import spaces
CRITIC_LR = 0.01 # learning rate for the critic model
ACTOR_LR = 0.01 # learning rate of the actor model
GAMMA = 0.95 # reward discount factor
TAU = 0.01 # soft update
BATCH_SIZE = 1024
-MAX_EPISODES = 25000 # stop condition:number of episodes
MAX_STEP_PER_EPISODE = 25 # maximum step per episode
STAT_RATE = 1000 # statistical interval of save model or count reward
@@ -79,36 +79,34 @@ def run_episode(env, agents):
def train_agent():
- env = MAenv(args.env)
+ env = MAenv(args.env, args.continuous_actions)
+ if args.continuous_actions:
+ assert isinstance(env.action_space[0], spaces.Box)
+
+ # print env info
logger.info('agent num: {}'.format(env.n))
- logger.info('observation_space: {}'.format(env.observation_space))
- logger.info('action_space: {}'.format(env.action_space))
logger.info('obs_shape_n: {}'.format(env.obs_shape_n))
logger.info('act_shape_n: {}'.format(env.act_shape_n))
+ logger.info('observation_space: {}'.format(env.observation_space))
+ logger.info('action_space: {}'.format(env.action_space))
for i in range(env.n):
logger.info('agent {} obs_low:{} obs_high:{}'.format(
i, env.observation_space[i].low, env.observation_space[i].high))
logger.info('agent {} act_n:{}'.format(i, env.act_shape_n[i]))
- if ('low' in dir(env.action_space[i])):
+ if (isinstance(env.action_space[i], spaces.Box)):
logger.info('agent {} act_low:{} act_high:{} act_shape:{}'.format(
i, env.action_space[i].low, env.action_space[i].high,
env.action_space[i].shape))
- logger.info('num_discrete_space:{}'.format(
- env.action_space[i].num_discrete_space))
-
- from gym import spaces
- from multiagent.multi_discrete import MultiDiscrete
- for space in env.action_space:
- assert (isinstance(space, spaces.Discrete)
- or isinstance(space, MultiDiscrete))
critic_in_dim = sum(env.obs_shape_n) + sum(env.act_shape_n)
logger.info('critic_in_dim: {}'.format(critic_in_dim))
+ # build agents
agents = []
for i in range(env.n):
- model = MAModel(env.obs_shape_n[i], env.act_shape_n[i], critic_in_dim)
+ model = MAModel(env.obs_shape_n[i], env.act_shape_n[i], critic_in_dim,
+ args.continuous_actions)
algorithm = MADDPG(
model,
agent_index=i,
@@ -142,7 +140,7 @@ def train_agent():
t_start = time.time()
logger.info('Starting...')
- while total_episodes <= MAX_EPISODES:
+ while total_episodes <= args.max_episodes:
# run an episode
ep_reward, ep_agent_rewards, steps = run_episode(env, agents)
summary.add_scalar('train_reward/episode', ep_reward, total_episodes)
@@ -208,8 +206,20 @@ def train_agent():
type=str,
default='./model',
help='directory for saving model')
+ parser.add_argument(
+ '--continuous_actions',
+ action='store_true',
+ default=False,
+ help='whether to use continuous action spaces')
+ parser.add_argument(
+ '--max_episodes',
+ type=int,
+ default=25000,
+ help='the maximum number of episodes')
+ parser.add_argument('--seed', type=int, default=0)
args = parser.parse_args()
+ print('========== args: ', args)
logger.set_dir('./train_log/' + str(args.env))
train_agent()
diff --git a/examples/MADDPG/.benchmark/maddpg_paddle.png b/examples/MADDPG/.benchmark/maddpg_paddle.png
deleted file mode 100644
index 244e9a625..000000000
Binary files a/examples/MADDPG/.benchmark/maddpg_paddle.png and /dev/null differ
diff --git a/examples/MADDPG/README.md b/examples/MADDPG/README.md
index 3f9ba2eb7..076bcbf15 100644
--- a/examples/MADDPG/README.md
+++ b/examples/MADDPG/README.md
@@ -10,7 +10,7 @@ A simple multi-agent particle world based on gym. Please see [here](https://gith
Mean episode reward (every 1000 episodes) in training process (totally 25000 episodes).
-
+
### Experiments result
@@ -19,37 +19,37 @@ Mean episode reward (every 1000 episodes) in training process (totally 25000 epi
simple
-
+
|
simple_adversary
-
+
|
simple_push
-
+
|
-simple_reference
-
+simple_crypto
+
|
simple_speaker_listener
-
+
|
simple_spread
-
+
|
simple_tag
-
+
|
simple_world_comm
-
+
|
@@ -58,9 +58,10 @@ simple_world_comm
### Dependencies:
+ python3.5+
+ [paddlepaddle>=2.0.0](https://github.com/PaddlePaddle/Paddle)
-+ [parl>=2.0.2](https://github.com/PaddlePaddle/PARL)
-+ [multiagent-particle-envs](https://github.com/openai/multiagent-particle-envs)
-+ gym==0.10.5
++ [parl>=2.0.4](https://github.com/PaddlePaddle/PARL)
++ PettingZoo==1.17.0
++ gym==0.23.1
+
### Start Training:
```
@@ -68,7 +69,12 @@ simple_world_comm
python train.py
# To train for other scenario, model is automatically saved every 1000 episodes
-# python train.py --env [ENV_NAME]
+python train.py --env [ENV_NAME]
# To show animation effects after training
-# python train.py --env [ENV_NAME] --show --restore
+python train.py --env [ENV_NAME] --show --restore
+
+# To train and evaluate scenarios with continuous action spaces
+python train.py --env [ENV_NAME] --continuous_actions
+python train.py --env [ENV_NAME] --continuous_actions --show --restore
+```
diff --git a/examples/MADDPG/simple_agent.py b/examples/MADDPG/simple_agent.py
index fbb837b82..7db79cf29 100644
--- a/examples/MADDPG/simple_agent.py
+++ b/examples/MADDPG/simple_agent.py
@@ -1,4 +1,4 @@
-# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -16,7 +16,6 @@
import paddle
import numpy as np
from parl.utils import ReplayMemory
-from parl.utils import machine_info, get_gpu_count
class MAAgent(parl.Agent):
diff --git a/examples/MADDPG/simple_model.py b/examples/MADDPG/simple_model.py
index 6fcb7cd7b..413dd94b0 100644
--- a/examples/MADDPG/simple_model.py
+++ b/examples/MADDPG/simple_model.py
@@ -1,4 +1,4 @@
-# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -19,9 +19,13 @@
class MAModel(parl.Model):
- def __init__(self, obs_dim, act_dim, critic_in_dim):
+ def __init__(self,
+ obs_dim,
+ act_dim,
+ critic_in_dim,
+ continuous_actions=False):
super(MAModel, self).__init__()
- self.actor_model = ActorModel(obs_dim, act_dim)
+ self.actor_model = ActorModel(obs_dim, act_dim, continuous_actions)
self.critic_model = CriticModel(critic_in_dim)
def policy(self, obs):
@@ -38,8 +42,9 @@ def get_critic_params(self):
class ActorModel(parl.Model):
- def __init__(self, obs_dim, act_dim):
+ def __init__(self, obs_dim, act_dim, continuous_actions=False):
super(ActorModel, self).__init__()
+ self.continuous_actions = continuous_actions
hid1_size = 64
hid2_size = 64
self.fc1 = nn.Linear(
@@ -57,11 +62,21 @@ def __init__(self, obs_dim, act_dim):
act_dim,
weight_attr=paddle.ParamAttr(
initializer=paddle.nn.initializer.XavierUniform()))
+ if self.continuous_actions:
+ std_hid_size = 64
+ self.std_fc = nn.Linear(
+ std_hid_size,
+ act_dim,
+ weight_attr=paddle.ParamAttr(
+ initializer=paddle.nn.initializer.XavierUniform()))
def forward(self, obs):
hid1 = F.relu(self.fc1(obs))
hid2 = F.relu(self.fc2(hid1))
means = self.fc3(hid2)
+ if self.continuous_actions:
+ act_std = self.std_fc(hid2)
+ return (means, act_std)
return means
diff --git a/examples/MADDPG/train.py b/examples/MADDPG/train.py
index a3f175ecc..b38896a6a 100644
--- a/examples/MADDPG/train.py
+++ b/examples/MADDPG/train.py
@@ -1,4 +1,4 @@
-# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -19,15 +19,15 @@
from simple_model import MAModel
from simple_agent import MAAgent
from parl.algorithms import MADDPG
-from parl.env.multiagent_simple_env import MAenv
+from parl.env.multiagent_env import MAenv
from parl.utils import logger, summary
+from gym import spaces
CRITIC_LR = 0.01 # learning rate for the critic model
ACTOR_LR = 0.01 # learning rate of the actor model
GAMMA = 0.95 # reward discount factor
TAU = 0.01 # soft update
BATCH_SIZE = 1024
-MAX_EPISODES = 25000 # stop condition:number of episodes
MAX_STEP_PER_EPISODE = 25 # maximum step per episode
STAT_RATE = 1000 # statistical interval of save model or count reward
@@ -79,36 +79,33 @@ def run_episode(env, agents):
def train_agent():
- env = MAenv(args.env)
+ env = MAenv(args.env, args.continuous_actions)
+ if args.continuous_actions:
+ assert isinstance(env.action_space[0], spaces.Box)
+
+ # print env info
logger.info('agent num: {}'.format(env.n))
- logger.info('observation_space: {}'.format(env.observation_space))
- logger.info('action_space: {}'.format(env.action_space))
logger.info('obs_shape_n: {}'.format(env.obs_shape_n))
logger.info('act_shape_n: {}'.format(env.act_shape_n))
-
+ logger.info('observation_space: {}'.format(env.observation_space))
+ logger.info('action_space: {}'.format(env.action_space))
for i in range(env.n):
logger.info('agent {} obs_low:{} obs_high:{}'.format(
i, env.observation_space[i].low, env.observation_space[i].high))
logger.info('agent {} act_n:{}'.format(i, env.act_shape_n[i]))
- if ('low' in dir(env.action_space[i])):
+ if (isinstance(env.action_space[i], spaces.Box)):
logger.info('agent {} act_low:{} act_high:{} act_shape:{}'.format(
i, env.action_space[i].low, env.action_space[i].high,
env.action_space[i].shape))
- logger.info('num_discrete_space:{}'.format(
- env.action_space[i].num_discrete_space))
-
- from gym import spaces
- from multiagent.multi_discrete import MultiDiscrete
- for space in env.action_space:
- assert (isinstance(space, spaces.Discrete)
- or isinstance(space, MultiDiscrete))
critic_in_dim = sum(env.obs_shape_n) + sum(env.act_shape_n)
logger.info('critic_in_dim: {}'.format(critic_in_dim))
+ # build agents
agents = []
for i in range(env.n):
- model = MAModel(env.obs_shape_n[i], env.act_shape_n[i], critic_in_dim)
+ model = MAModel(env.obs_shape_n[i], env.act_shape_n[i], critic_in_dim,
+ args.continuous_actions)
algorithm = MADDPG(
model,
agent_index=i,
@@ -142,7 +139,7 @@ def train_agent():
t_start = time.time()
logger.info('Starting...')
- while total_episodes <= MAX_EPISODES:
+ while total_episodes <= args.max_episodes:
# run an episode
ep_reward, ep_agent_rewards, steps = run_episode(env, agents)
summary.add_scalar('train_reward/episode', ep_reward, total_episodes)
@@ -208,8 +205,19 @@ def train_agent():
type=str,
default='./model',
help='directory for saving model')
+ parser.add_argument(
+ '--continuous_actions',
+ action='store_true',
+ default=False,
+ help='whether to use continuous action spaces')
+ parser.add_argument(
+ '--max_episodes',
+ type=int,
+ default=25000,
+ help='stop condition: number of episodes')
args = parser.parse_args()
+ print('========== args: ', args)
logger.set_dir('./train_log/' + str(args.env))
train_agent()
diff --git a/parl/algorithms/paddle/maddpg.py b/parl/algorithms/paddle/maddpg.py
index e25afe348..7906d8574 100644
--- a/parl/algorithms/paddle/maddpg.py
+++ b/parl/algorithms/paddle/maddpg.py
@@ -1,4 +1,4 @@
-# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -21,6 +21,7 @@
__all__ = ['MADDPG']
+from parl.core.paddle.policy_distribution import DiagGaussianDistribution
from parl.core.paddle.policy_distribution import SoftCategoricalDistribution
from parl.core.paddle.policy_distribution import SoftMultiCategoricalDistribution
@@ -42,6 +43,9 @@ def SoftPDistribution(logits, act_space):
elif (hasattr(act_space, 'num_discrete_space')):
return SoftMultiCategoricalDistribution(logits, act_space.low,
act_space.high)
+ # is instance of gym.spaces.Box
+ elif (hasattr(act_space, 'high')):
+ return DiagGaussianDistribution(logits)
else:
raise AssertionError("act_space must be instance of \
gym.spaces.Discrete or multiagent.multi_discrete.MultiDiscrete")
@@ -80,6 +84,11 @@ def __init__(self,
assert isinstance(actor_lr, float)
assert isinstance(critic_lr, float)
+ self.continuous_actions = False
+ if not len(act_space) == 0 and hasattr(act_space[0], 'high') \
+ and not hasattr(act_space[0], 'num_discrete_space'):
+ self.continuous_actions = True
+
self.agent_index = agent_index
self.act_space = act_space
self.gamma = gamma
@@ -117,6 +126,8 @@ def predict(self, obs, use_target_model=False):
action = SoftPDistribution(
logits=policy,
act_space=self.act_space[self.agent_index]).sample()
+ if self.continuous_actions:
+ action = paddle.tanh(action)
return action
def Q(self, obs_n, act_n, use_target_model=False):
@@ -150,6 +161,8 @@ def _actor_learn(self, obs_n, act_n):
sample_this_action = SoftPDistribution(
logits=this_policy,
act_space=self.act_space[self.agent_index]).sample()
+ if self.continuous_actions:
+ sample_this_action = paddle.tanh(sample_this_action)
# action_input_n = deepcopy(act_n)
action_input_n = act_n + []
@@ -157,6 +170,9 @@ def _actor_learn(self, obs_n, act_n):
eval_q = self.Q(obs_n, action_input_n)
act_cost = paddle.mean(-1.0 * eval_q)
+ # when continuous, 'this_policy' is a tuple with two elements: (mean, logstd)
+ if self.continuous_actions:
+ this_policy = paddle.concat(this_policy, axis=-1)
act_reg = paddle.mean(paddle.square(this_policy))
cost = act_cost + act_reg * 1e-3
diff --git a/parl/algorithms/torch/maddpg.py b/parl/algorithms/torch/maddpg.py
index b9b9187de..70d2005b8 100644
--- a/parl/algorithms/torch/maddpg.py
+++ b/parl/algorithms/torch/maddpg.py
@@ -1,4 +1,4 @@
-# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -21,6 +21,7 @@
__all__ = ['MADDPG']
+from parl.core.torch.policy_distribution import DiagGaussianDistribution
from parl.core.torch.policy_distribution import SoftCategoricalDistribution
from parl.core.torch.policy_distribution import SoftMultiCategoricalDistribution
@@ -29,7 +30,7 @@ def SoftPDistribution(logits, act_space):
""" Select SoftCategoricalDistribution or SoftMultiCategoricalDistribution according to act_space.
Args:
- logits (paddle tensor): the output of policy model
+ logits (torch tensor): the output of policy model
act_space: action space, must be gym.spaces.Discrete or multiagent.multi_discrete.MultiDiscrete
Returns:
@@ -42,6 +43,10 @@ def SoftPDistribution(logits, act_space):
elif (hasattr(act_space, 'num_discrete_space')):
return SoftMultiCategoricalDistribution(logits, act_space.low,
act_space.high)
+ # is instance of gym.spaces.Box
+ elif (hasattr(act_space, 'high')):
+ return DiagGaussianDistribution(logits)
+
else:
raise AssertionError("act_space must be instance of \
gym.spaces.Discrete or multiagent.multi_discrete.MultiDiscrete")
@@ -80,6 +85,11 @@ def __init__(self,
assert isinstance(actor_lr, float)
assert isinstance(critic_lr, float)
+ self.continuous_actions = False
+ if not len(act_space) == 0 and hasattr(act_space[0], 'high') \
+ and not hasattr(act_space[0], 'num_discrete_space'):
+ self.continuous_actions = True
+
self.agent_index = agent_index
self.act_space = act_space
self.gamma = gamma
@@ -100,31 +110,37 @@ def __init__(self,
def predict(self, obs, use_target_model=False):
""" use the policy model to predict actions
+
Args:
- obs (paddle tensor): observation, shape([B] + shape of obs_n[agent_index])
+ obs (torch tensor): observation, shape([B] + shape of obs_n[agent_index])
use_target_model (bool): use target_model or not
-
+
Returns:
- act (paddle tensor): action, shape([B] + shape of act_n[agent_index])
+ act (torch tensor): action, shape([B] + shape of act_n[agent_index])
"""
if use_target_model:
policy = self.target_model.policy(obs)
else:
policy = self.model.policy(obs)
+
action = SoftPDistribution(
logits=policy,
act_space=self.act_space[self.agent_index]).sample()
+ if self.continuous_actions:
+ action = torch.tanh(action)
+
return action
def Q(self, obs_n, act_n, use_target_model=False):
""" use the value model to predict Q values
- Args:
- obs_n (list of paddle tensor): all agents' observation, len(agent's num) + shape([B] + shape of obs_n)
- act_n (list of paddle tensor): all agents' action, len(agent's num) + shape([B] + shape of act_n)
+
+ Args:
+ obs_n (list of torch tensor): all agents' observation, len(agent's num) + shape([B] + shape of obs_n)
+ act_n (list of torch tensor): all agents' action, len(agent's num) + shape([B] + shape of act_n)
use_target_model (bool): use target_model or not
Returns:
- Q (paddle tensor): Q value of this agent, shape([B])
+ Q (torch tensor): Q value of this agent, shape([B])
"""
if use_target_model:
return self.target_model.value(obs_n, act_n)
@@ -146,6 +162,8 @@ def _actor_learn(self, obs_n, act_n):
sample_this_action = SoftPDistribution(
logits=this_policy,
act_space=self.act_space[self.agent_index]).sample()
+ if self.continuous_actions:
+ sample_this_action = torch.tanh(sample_this_action)
# action_input_n = deepcopy(act_n)
action_input_n = act_n + []
@@ -153,6 +171,9 @@ def _actor_learn(self, obs_n, act_n):
eval_q = self.Q(obs_n, action_input_n)
act_cost = torch.mean(-1.0 * eval_q)
+ # when continuous, 'this_policy' is a tuple with two elements: (mean, logstd)
+ if self.continuous_actions:
+ this_policy = torch.cat(this_policy, dim=-1)
act_reg = torch.mean(torch.square(this_policy))
cost = act_cost + act_reg * 1e-3
@@ -174,6 +195,12 @@ def _critic_learn(self, obs_n, act_n, target_q):
return cost
def sync_target(self, decay=None):
+ """ update the target network with the training network
+
+ Args:
+ decay(float): the decaying factor used when updating the target network with the training network.
+ 0 means a hard update (direct assignment); None means a soft update controlled by the hyperparameter `tau`.
+ """
if decay is None:
decay = 1.0 - self.tau
self.model.sync_weights_to(self.target_model, decay=decay)
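
An illustrative check (with assumed example spaces) of the `act_space` predicate used in both MADDPG implementations above: gym's `Box` exposes `high` but not `num_discrete_space`, while the particle-env `MultiDiscrete` exposes both, which is why the second condition is needed.

```python
from gym import spaces

box = spaces.Box(low=-1.0, high=1.0, shape=(2,))
disc = spaces.Discrete(5)
for space in (box, disc):
    # continuous iff the space has `high` but no `num_discrete_space`
    is_continuous = hasattr(space, 'high') and not hasattr(space, 'num_discrete_space')
    print(type(space).__name__, is_continuous)  # -> Box True, Discrete False
```
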
diff --git a/parl/core/paddle/policy_distribution.py b/parl/core/paddle/policy_distribution.py
index 95ccbf4ca..608772014 100644
--- a/parl/core/paddle/policy_distribution.py
+++ b/parl/core/paddle/policy_distribution.py
@@ -1,4 +1,4 @@
-# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -14,6 +14,7 @@
import paddle
import paddle.nn.functional as F
+import numpy as np
__all__ = [
'PolicyDistribution', 'CategoricalDistribution',
@@ -39,6 +40,78 @@ def logp(self, actions):
raise NotImplementedError
+class DiagGaussianDistribution(PolicyDistribution):
+ """DiagGaussian distribution for continuous action spaces."""
+
+ def __init__(self, logits):
+ """
+ Args:
+ logits: A tuple of (mean, logstd)
+ mean: A float32 tensor with shape [BATCH_SIZE, NUM_ACTIONS], the mean of the Gaussian policy
+ logstd: A float32 tensor with shape [BATCH_SIZE, NUM_ACTIONS]
+ """
+ assert len(logits) == 2
+ assert len(logits[0].shape) == 2 and len(logits[1].shape) == 2
+ self.logits = logits
+ (mean, logstd) = logits
+ self.mean = mean
+ self.logstd = logstd
+
+ self.std = paddle.exp(self.logstd)
+
+ def sample(self):
+ """
+ Returns:
+ sample_action: A float32 tensor with shape [BATCH_SIZE, NUM_ACTIONS] of sampled actions,
+ drawn by adding Gaussian noise (scaled by std) to the mean.
+ """
+ mean_shape = paddle.to_tensor(self.mean.shape, dtype='int64')
+ random_normal = paddle.normal(shape=mean_shape)
+ return self.mean + self.std * random_normal
+
+ def entropy(self):
+ """
+ Returns:
+ entropy: A float32 tensor with shape [BATCH_SIZE], the entropy of this policy distribution.
+ """
+ entropy = paddle.sum(
+ self.logstd + 0.5 * np.log(2.0 * np.pi * np.e), axis=1)
+ return entropy
+
+ def logp(self, actions):
+ """
+ Args:
+ actions: A float32 tensor with shape [BATCH_SIZE, NUM_ACTIONS]
+
+ Returns:
+ actions_log_prob: A float32 tensor with shape [BATCH_SIZE]
+ """
+ assert len(actions.shape) == 2
+
+ norm_actions = paddle.sum(
+ paddle.square((actions - self.mean) / self.std), axis=1)
+ pi_item = 0.5 * np.log(2.0 * np.pi) * actions.shape[1]
+ actions_log_prob = -0.5 * norm_actions - pi_item - paddle.sum(
+ self.logstd, axis=1)
+
+ return actions_log_prob
+
+ def kl(self, other):
+ """
+ Args:
+ other: object of DiagGaussianDistribution
+
+ Returns:
+ kl: A float32 tensor with shape [BATCH_SIZE]
+ """
+ assert isinstance(other, DiagGaussianDistribution)
+
+ temp = (paddle.square(self.std) + paddle.square(self.mean - other.mean)
+ ) / (2.0 * paddle.square(other.std))
+ kl = paddle.sum(other.logstd - self.logstd + temp - 0.5, axis=1)
+ return kl
+
+
class CategoricalDistribution(PolicyDistribution):
"""Categorical distribution for discrete action spaces."""
diff --git a/parl/core/torch/policy_distribution.py b/parl/core/torch/policy_distribution.py
index 2efe15b27..77a6a41d0 100644
--- a/parl/core/torch/policy_distribution.py
+++ b/parl/core/torch/policy_distribution.py
@@ -1,4 +1,4 @@
-# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -14,6 +14,7 @@
import torch
import torch.nn.functional as F
+import numpy as np
__all__ = [
'PolicyDistribution', 'CategoricalDistribution',
@@ -39,6 +40,77 @@ def logp(self, actions):
raise NotImplementedError
+class DiagGaussianDistribution(PolicyDistribution):
+ """DiagGaussian distribution for continuous action spaces."""
+
+ def __init__(self, logits):
+ """
+ Args:
+ logits: A tuple of (mean, logstd)
+ mean: A float32 tensor with shape [BATCH_SIZE, NUM_ACTIONS], the mean of the Gaussian policy
+ logstd: A float32 tensor with shape [BATCH_SIZE, NUM_ACTIONS]
+ """
+ assert len(logits) == 2
+ assert len(logits[0].shape) == 2 and len(logits[1].shape) == 2
+ self.logits = logits
+ (mean, logstd) = logits
+ self.mean = mean
+ self.logstd = logstd
+
+ self.std = torch.exp(self.logstd)
+
+ def sample(self):
+ """
+ Returns:
+ sample_action: A float32 tensor with shape [BATCH_SIZE, NUM_ACTIONS] of sampled actions,
+ drawn by adding Gaussian noise (scaled by std) to the mean.
+ """
+ random_normal = torch.randn_like(self.mean)
+ return self.mean + self.std * random_normal
+
+ def entropy(self):
+ """
+ Returns:
+ entropy: A float32 tensor with shape [BATCH_SIZE], the entropy of this policy distribution.
+ """
+ entropy = torch.sum(
+ self.logstd + 0.5 * np.log(2.0 * np.pi * np.e), axis=1)
+ return entropy
+
+ def logp(self, actions):
+ """
+ Args:
+ actions: A float32 tensor with shape [BATCH_SIZE, NUM_ACTIONS]
+ Returns:
+ actions_log_prob: A float32 tensor with shape [BATCH_SIZE]
+ """
+ assert len(actions.shape) == 2
+
+ norm_actions = torch.sum(
+ torch.square((actions - self.mean) / self.std), axis=1)
+ pi_item = 0.5 * np.log(2.0 * np.pi) * actions.shape[1]
+ actions_log_prob = -0.5 * norm_actions - pi_item - torch.sum(
+ self.logstd, axis=1)
+
+ return actions_log_prob
+
+ def kl(self, other):
+ """
+ Args:
+ other: object of DiagGaussianDistribution
+ Returns:
+ kl: A float32 tensor with shape [BATCH_SIZE]
+ """
+ assert isinstance(other, DiagGaussianDistribution)
+
+ temp = (torch.square(self.std) + torch.square(self.mean - other.mean)
+ ) / (2.0 * torch.square(other.std))
+ kl = torch.sum(other.logstd - self.logstd + temp - 0.5, axis=1)
+ return kl
+
+
class CategoricalDistribution(PolicyDistribution):
"""Categorical distribution for discrete action spaces."""
diff --git a/parl/env/multiagent_env.py b/parl/env/multiagent_env.py
new file mode 100644
index 000000000..b44c206c7
--- /dev/null
+++ b/parl/env/multiagent_env.py
@@ -0,0 +1,141 @@
+# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import numpy as np
+try:
+ import gym
+ from gym import spaces
+ from pettingzoo.mpe import simple_v2
+ from pettingzoo.mpe import simple_adversary_v2
+ from pettingzoo.mpe import simple_crypto_v2
+ from pettingzoo.mpe import simple_push_v2
+ from pettingzoo.mpe import simple_speaker_listener_v3
+ from pettingzoo.mpe import simple_spread_v2
+ from pettingzoo.mpe import simple_tag_v2
+ from pettingzoo.mpe import simple_world_comm_v2
+except:
+ raise ImportError('Cannot use MAenv from parl.env.multiagent_env. \n \
+ Try `pip install PettingZoo==1.17.0` and `pip install gym==0.23.1` \n \
+ (PettingZoo 1.17.0 requires gym>=0.21.0)')
+
+
+def MAenv(scenario_name, continuous_actions=False):
+ env_list = [
+ 'simple', 'simple_adversary', 'simple_crypto', 'simple_push',
+ 'simple_speaker_listener', 'simple_spread', 'simple_tag',
+ 'simple_world_comm'
+ ]
+ assert scenario_name in env_list, 'Env {} not found (valid envs include {})'.format(
+ scenario_name, env_list)
+ if scenario_name == 'simple':
+ env = simple_v2.parallel_env(
+ max_cycles=25, continuous_actions=continuous_actions)
+ elif scenario_name == 'simple_adversary':
+ env = simple_adversary_v2.parallel_env(
+ N=2, max_cycles=25, continuous_actions=continuous_actions)
+ elif scenario_name == 'simple_crypto':
+ env = simple_crypto_v2.parallel_env(
+ max_cycles=25, continuous_actions=continuous_actions)
+ elif scenario_name == 'simple_push':
+ env = simple_push_v2.parallel_env(
+ max_cycles=25, continuous_actions=continuous_actions)
+ elif scenario_name == 'simple_speaker_listener':
+ env = simple_speaker_listener_v3.parallel_env(
+ max_cycles=25, continuous_actions=continuous_actions)
+ elif scenario_name == 'simple_spread':
+ env = simple_spread_v2.parallel_env(
+ N=3,
+ local_ratio=0,
+ max_cycles=25,
+ continuous_actions=continuous_actions)
+ elif scenario_name == 'simple_tag':
+ env = simple_tag_v2.parallel_env(
+ num_good=1,
+ num_adversaries=3,
+ num_obstacles=2,
+ max_cycles=25,
+ continuous_actions=continuous_actions)
+ elif scenario_name == 'simple_world_comm':
+ env = simple_world_comm_v2.parallel_env(
+ num_good=2,
+ num_adversaries=4,
+ num_obstacles=1,
+ num_food=2,
+ max_cycles=25,
+ num_forests=2,
+ continuous_actions=continuous_actions)
+ else:
+ pass
+
+ env = mpe_wrapper_for_pettingzoo(env, continuous_actions)
+ return env
+
+
+class mpe_wrapper_for_pettingzoo(gym.Wrapper):
+ def __init__(self, env=None, continuous_actions=False):
+ gym.Wrapper.__init__(self, env)
+ self.continuous_actions = continuous_actions
+ self.observation_space = list(self.observation_spaces.values())
+ self.action_space = list(self.action_spaces.values())
+ assert len(self.observation_space) == len(self.action_space)
+ self.n = len(self.observation_space)
+ self.agents_name = list(self.observation_spaces.keys())
+ self.obs_shape_n = [
+ self.get_shape(self.observation_space[i]) for i in range(self.n)
+ ]
+ self.act_shape_n = [
+ self.get_shape(self.action_space[i]) for i in range(self.n)
+ ]
+
+ def get_shape(self, input_space):
+ """
+ Args:
+ input_space: environment space
+
+ Returns:
+ space shape
+ """
+ if (isinstance(input_space, spaces.Box)):
+ if (len(input_space.shape) == 1):
+ return input_space.shape[0]
+ else:
+ return input_space.shape
+ elif (isinstance(input_space, spaces.Discrete)):
+ return input_space.n
+ else:
+ print('[Error] shape is {}, not Box or Discrete'.format(
+ input_space.shape))
+ raise NotImplementedError
+
+ def reset(self):
+ obs = self.env.reset()
+ return list(obs.values())
+
+ def step(self, actions):
+ actions_dict = dict()
+ for i, act in enumerate(actions):
+ agent = self.agents_name[i]
+ if self.continuous_actions:
+ assert np.all(((act<=1.0 + 1e-3), (act>=-1.0 - 1e-3))), \
+ 'the action should be in range [-1.0, 1.0], but got {}'.format(act)
+ high = self.action_space[i].high
+ low = self.action_space[i].low
+ mapped_action = low + (act - (-1.0)) * ((high - low) / 2.0)
+ mapped_action = np.clip(mapped_action, low, high)
+ actions_dict[agent] = mapped_action
+ else:
+ actions_dict[agent] = np.argmax(act)
+ obs, reward, done, info = self.env.step(actions_dict)
+ return list(obs.values()), list(reward.values()), list(
+ done.values()), list(info.values())
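
A minimal usage sketch of the new wrapper (assumes `PettingZoo==1.17.0` and `gym==0.23.1` are installed): build a parallel MPE scenario with continuous actions and run one episode with random actions in [-1, 1], which `step` rescales to each agent's Box range.

```python
import numpy as np

from parl.env.multiagent_env import MAenv

env = MAenv('simple_spread', continuous_actions=True)
obs_n = env.reset()
for _ in range(25):  # every scenario above uses max_cycles=25
    # one normalized action per agent, as the wrapper's step() expects
    act_n = [np.random.uniform(-1.0, 1.0, size=dim) for dim in env.act_shape_n]
    obs_n, reward_n, done_n, info_n = env.step(act_n)
```
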
diff --git a/parl/env/multiagent_simple_env.py b/parl/env/multiagent_simple_env.py
index 913ebda3e..af9115866 100644
--- a/parl/env/multiagent_simple_env.py
+++ b/parl/env/multiagent_simple_env.py
@@ -12,56 +12,69 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from gym import spaces
-from multiagent.multi_discrete import MultiDiscrete
-from multiagent.environment import MultiAgentEnv
-import multiagent.scenarios as scenarios
+try:
+ from gym import spaces
+ from multiagent.multi_discrete import MultiDiscrete
+ from multiagent.environment import MultiAgentEnv
+ import multiagent.scenarios as scenarios
+ from parl.utils import logger
+ class MAenv(MultiAgentEnv):
+ """ multiagent environment warppers for maddpg
+ """
-class MAenv(MultiAgentEnv):
- """ multiagent environment warppers for maddpg
- """
-
- def __init__(self, scenario_name):
- env_list = [
- 'simple', 'simple_adversary', 'simple_crypto', 'simple_push',
- 'simple_reference', 'simple_speaker_listener', 'simple_spread',
- 'simple_tag', 'simple_world_comm'
- ]
- assert scenario_name in env_list, 'Env {} not found (valid envs include {})'.format(
- scenario_name, env_list)
- # load scenario from script
- scenario = scenarios.load(scenario_name + ".py").Scenario()
- # create world
- world = scenario.make_world()
- # initial multiagent environment
- super().__init__(world, scenario.reset_world, scenario.reward,
- scenario.observation)
- self.obs_shape_n = [
- self.get_shape(self.observation_space[i]) for i in range(self.n)
- ]
- self.act_shape_n = [
- self.get_shape(self.action_space[i]) for i in range(self.n)
- ]
+ def __init__(self, scenario_name):
+ env_list = [
+ 'simple', 'simple_adversary', 'simple_crypto', 'simple_push',
+ 'simple_reference', 'simple_speaker_listener', 'simple_spread',
+ 'simple_tag', 'simple_world_comm'
+ ]
+ assert scenario_name in env_list, 'Env {} not found (valid envs include {})'.format(
+ scenario_name, env_list)
+ # load scenario from script
+ scenario = scenarios.load(scenario_name + ".py").Scenario()
+ # create world
+ world = scenario.make_world()
+ # initial multiagent environment
+ super().__init__(world, scenario.reset_world, scenario.reward,
+ scenario.observation)
+ self.obs_shape_n = [
+ self.get_shape(self.observation_space[i])
+ for i in range(self.n)
+ ]
+ self.act_shape_n = [
+ self.get_shape(self.action_space[i]) for i in range(self.n)
+ ]
- def get_shape(self, input_space):
- """
- Args:
- input_space: environment space
+ def get_shape(self, input_space):
+ """
+ Args:
+ input_space: environment space
- Returns:
- space shape
- """
- if (isinstance(input_space, spaces.Box)):
- if (len(input_space.shape) == 1):
- return input_space.shape[0]
+ Returns:
+ space shape
+ """
+ if (isinstance(input_space, spaces.Box)):
+ if (len(input_space.shape) == 1):
+ return input_space.shape[0]
+ else:
+ return input_space.shape
+ elif (isinstance(input_space, spaces.Discrete)):
+ return input_space.n
+ elif (isinstance(input_space, MultiDiscrete)):
+ return sum(input_space.high - input_space.low + 1)
else:
- return input_space.shape
- elif (isinstance(input_space, spaces.Discrete)):
- return input_space.n
- elif (isinstance(input_space, MultiDiscrete)):
- return sum(input_space.high - input_space.low + 1)
- else:
- print('[Error] shape is {}, not Box or Discrete or MultiDiscrete'.
- format(input_space.shape))
- raise NotImplementedError
+ print(
+ '[Error] shape is {}, not Box or Discrete or MultiDiscrete'
+ .format(input_space.shape))
+ raise NotImplementedError
+
+ logger.warning(
+ 'The `MAenv` from `parl.env.multiagent_simple_env` is deprecated since parl 2.0.4 and will be removed in parl 3.0. \n \
+ We recommend using `from parl.env.multiagent_env import MAenv` instead, which supports continuous control.'
+ )
+except:
+ raise ImportError(
+ 'Cannot use MAenv from parl.env.multiagent_simple_env. \n \
+ Please install the multiagent package following https://github.com/openai/multiagent-particle-envs \
+ as well as `pip install gym==0.10.5`')