From faf7884f0c1861cb3b02c310ffa98bb83c489fe9 Mon Sep 17 00:00:00 2001 From: Michael Panchenko Date: Wed, 14 Aug 2024 12:25:12 +0200 Subject: [PATCH 01/11] Renamed current RandomPolicy to MARLRandomPolicy --- docs/01_tutorials/04_tictactoe.rst | 29 ++++++++++++++++++++++++++--- test/pettingzoo/tic_tac_toe.py | 9 +++++++-- tianshou/policy/__init__.py | 4 ++-- tianshou/policy/random.py | 10 +++++----- 4 files changed, 40 insertions(+), 12 deletions(-) diff --git a/docs/01_tutorials/04_tictactoe.rst b/docs/01_tutorials/04_tictactoe.rst index b15b11ace..4247c539d 100644 --- a/docs/01_tutorials/04_tictactoe.rst +++ b/docs/01_tutorials/04_tictactoe.rst @@ -122,11 +122,34 @@ Two Random Agents .. Figure:: ../_static/images/marl.png -Tianshou already provides some builtin classes for multi-agent learning. You can check out the API documentation for details. Here we use :class:`~tianshou.policy.RandomPolicy` and :class:`~tianshou.policy.MultiAgentPolicyManager`. The figure on the right gives an intuitive explanation. +Tianshou already provides some builtin classes for multi-agent learning. You can check out the API documentation for details. Here we use :class:`~tianshou.policy.MARLRandomPolicy` and :class:`~tianshou.policy.MultiAgentPolicyManager`. The figure on the right gives an intuitive explanation. :: >>> from tianshou.data import Collector + >>> from tianshou.env import DummyVectorEnv + >>> from tianshou.policy import RandomPolicy, MultiAgentPolicyManager + >>> + >>> # agents should be wrapped into one policy, + >>> # which is responsible for calling the acting agent correctly + >>> # here we use two random agents + >>> policy = MultiAgentPolicyManager( + >>> [MARLRandomPolicy(action_space=env.action_space), RandomPolicy(action_space=env.action_space)], env + >>> ) + >>> + >>> # need to vectorize the environment for the collector + >>> env = DummyVectorEnv([lambda: env]) + >>> + >>> # use collectors to collect a episode of trajectories + >>> # the reward is a vector, so we need a scalar metric to monitor the training + >>> collector = Collector(policy, env) + >>> + >>> # you will see a long trajectory showing the board status at each timestep + >>> result = collector.collect(n_episode=1, render=.1) + (only show the last 3 steps) + | | + X | X | - + >>> from tianshou.env import DummyVectorEnv >>> from tianshou.policy import RandomPolicy, MultiAgentPolicyManager >>> @@ -202,7 +225,7 @@ So let's start to train our Tic-Tac-Toe agent! First, import some required modul BasePolicy, DQNPolicy, MultiAgentPolicyManager, - RandomPolicy, + MARLRandomPolicy, ) from tianshou.trainer import OffpolicyTrainer from tianshou.utils import TensorboardLogger @@ -286,7 +309,7 @@ The following ``get_agents`` function returns agents and their optimizers from e - The action model we use is an instance of :class:`~tianshou.utils.net.common.Net`, essentially a multi-layer perceptron with the ReLU activation function; - The network model is passed to a :class:`~tianshou.policy.DQNPolicy`, where actions are selected according to both the action mask and their Q-values; -- The opponent can be either a random agent :class:`~tianshou.policy.RandomPolicy` that randomly chooses an action from legal actions, or it can be a pre-trained :class:`~tianshou.policy.DQNPolicy` allowing learned agents to play with themselves. +- The opponent can be either a random agent :class:`~tianshou.policy.MARLRandomPolicy` that randomly chooses an action from legal actions, or it can be a pre-trained :class:`~tianshou.policy.DQNPolicy` allowing learned agents to play with themselves. Both agents are passed to :class:`~tianshou.policy.MultiAgentPolicyManager`, which is responsible to call the correct agent according to the ``agent_id`` in the observation. :class:`~tianshou.policy.MultiAgentPolicyManager` also dispatches data to each agent according to ``agent_id``, so that each agent seems to play with a virtual single-agent environment. diff --git a/test/pettingzoo/tic_tac_toe.py b/test/pettingzoo/tic_tac_toe.py index 966c9e04c..59a5e45be 100644 --- a/test/pettingzoo/tic_tac_toe.py +++ b/test/pettingzoo/tic_tac_toe.py @@ -13,7 +13,12 @@ from tianshou.data.stats import InfoStats from tianshou.env import DummyVectorEnv from tianshou.env.pettingzoo_env import PettingZooEnv -from tianshou.policy import BasePolicy, DQNPolicy, MultiAgentPolicyManager, RandomPolicy +from tianshou.policy import ( + BasePolicy, + DQNPolicy, + MARLRandomPolicy, + MultiAgentPolicyManager, +) from tianshou.trainer import OffpolicyTrainer from tianshou.utils import TensorboardLogger from tianshou.utils.net.common import Net @@ -131,7 +136,7 @@ def get_agents( agent_opponent = deepcopy(agent_learn) agent_opponent.load_state_dict(torch.load(args.opponent_path)) else: - agent_opponent = RandomPolicy(action_space=env.action_space) + agent_opponent = MARLRandomPolicy(action_space=env.action_space) if args.agent_id == 1: agents = [agent_learn, agent_opponent] diff --git a/tianshou/policy/__init__.py b/tianshou/policy/__init__.py index 5e6967ad7..a9b944da8 100644 --- a/tianshou/policy/__init__.py +++ b/tianshou/policy/__init__.py @@ -2,7 +2,7 @@ # isort:skip_file from tianshou.policy.base import BasePolicy, TrainingStats -from tianshou.policy.random import RandomPolicy +from tianshou.policy.random import MARLRandomPolicy from tianshou.policy.modelfree.dqn import DQNPolicy from tianshou.policy.modelfree.bdq import BranchingDQNPolicy from tianshou.policy.modelfree.c51 import C51Policy @@ -34,7 +34,7 @@ __all__ = [ "BasePolicy", - "RandomPolicy", + "MARLRandomPolicy", "DQNPolicy", "BranchingDQNPolicy", "C51Policy", diff --git a/tianshou/policy/random.py b/tianshou/policy/random.py index 943ae99f2..bf665bc2b 100644 --- a/tianshou/policy/random.py +++ b/tianshou/policy/random.py @@ -9,14 +9,14 @@ from tianshou.policy.base import TrainingStats -class RandomTrainingStats(TrainingStats): +class MARLRandomTrainingStats(TrainingStats): pass -TRandomTrainingStats = TypeVar("TRandomTrainingStats", bound=RandomTrainingStats) +TMARLRandomTrainingStats = TypeVar("TMARLRandomTrainingStats", bound=MARLRandomTrainingStats) -class RandomPolicy(BasePolicy[TRandomTrainingStats]): +class MARLRandomPolicy(BasePolicy[TMARLRandomTrainingStats]): """A random agent used in multi-agent learning. It randomly chooses an action from the legal action. @@ -49,6 +49,6 @@ def forward( result = Batch(act=logits.argmax(axis=-1)) return cast(ActBatchProtocol, result) - def learn(self, batch: RolloutBatchProtocol, *args: Any, **kwargs: Any) -> TRandomTrainingStats: # type: ignore + def learn(self, batch: RolloutBatchProtocol, *args: Any, **kwargs: Any) -> TMARLRandomTrainingStats: # type: ignore """Since a random agent learns nothing, it returns an empty dict.""" - return RandomTrainingStats() # type: ignore[return-value] + return MARLRandomTrainingStats() # type: ignore[return-value] From d0be0010e71fef8812dc913940fdce891d905c36 Mon Sep 17 00:00:00 2001 From: Michael Panchenko Date: Wed, 14 Aug 2024 13:03:55 +0200 Subject: [PATCH 02/11] Added RandomActor, RandomActionPolicy and support for it in hl-interfaces --- tianshou/highlevel/agent.py | 6 +++++ tianshou/policy/base.py | 25 ++++++++++++++++++++ tianshou/utils/net/common.py | 44 ++++++++++++++++++++++++++++++++++++ 3 files changed, 75 insertions(+) diff --git a/tianshou/highlevel/agent.py b/tianshou/highlevel/agent.py index 81141a8a6..9fffa3ad5 100644 --- a/tianshou/highlevel/agent.py +++ b/tianshou/highlevel/agent.py @@ -59,6 +59,7 @@ TD3Policy, TRPOPolicy, ) +from tianshou.policy.base import RandomActionPolicy from tianshou.trainer import BaseTrainer, OffpolicyTrainer, OnpolicyTrainer from tianshou.utils.net.common import ActorCritic @@ -245,6 +246,11 @@ def create_trainer( ) +class RandomActionAgentFactory(OnPolicyAgentFactory): + def _create_policy(self, envs: Environments, device: TDevice) -> RandomActionPolicy: + return RandomActionPolicy(envs.get_action_space()) + + class PGAgentFactory(OnPolicyAgentFactory): def __init__( self, diff --git a/tianshou/policy/base.py b/tianshou/policy/base.py index d886180a5..f2c027fa2 100644 --- a/tianshou/policy/base.py +++ b/tianshou/policy/base.py @@ -25,6 +25,7 @@ RolloutBatchProtocol, ) from tianshou.utils import MultipleLRSchedulers +from tianshou.utils.net.common import RandomActor from tianshou.utils.print import DataclassPPrintMixin from tianshou.utils.torch_utils import policy_within_training_step, torch_train_mode @@ -693,6 +694,30 @@ def _compile() -> None: _nstep_return(f64, b, f32.reshape(-1, 1), i64, 0.1, 1) +class RandomActionPolicy(BasePolicy): + def __init__( + self, + action_space: gym.Space, + ) -> None: + super().__init__(action_space=action_space) + if not isinstance(action_space, gym.spaces.Discrete | gym.spaces.Box): + raise NotImplementedError( + f"RandomActionPolicy currently only supports Discrete and Box action spaces, but got {action_space}.", + ) + self.actor = RandomActor(action_space) + + def forward( + self, + batch: ObsBatchProtocol, + state: dict | BatchProtocol | np.ndarray | None = None, + **kwargs: Any, + ) -> ActBatchProtocol: + return cast(ActBatchProtocol, Batch(act=self.actor(batch.obs))) + + def learn(self, batch: RolloutBatchProtocol, *args: Any, **kwargs: Any) -> TrainingStats: + return TrainingStats() + + # TODO: rename? See docstring @njit def _gae_return( diff --git a/tianshou/utils/net/common.py b/tianshou/utils/net/common.py index eceee100f..7510ae7e7 100644 --- a/tianshou/utils/net/common.py +++ b/tianshou/utils/net/common.py @@ -4,10 +4,12 @@ import numpy as np import torch +from gymnasium import spaces from torch import nn from tianshou.data.batch import Batch from tianshou.data.types import RecurrentStateBatch +from tianshou.utils.space_info import ActionSpaceInfo ModuleType = type[nn.Module] ArgsType = tuple[Any, ...] | dict[Any, Any] | Sequence[tuple[Any, ...]] | Sequence[dict[Any, Any]] @@ -632,6 +634,48 @@ def forward( pass +class RandomActor(BaseActor): + """An actor that returns random actions. + + For continuous action spaces, forward returns a batch of random actions sampled from the action space. + For discrete action spaces, forward returns a batch of n-dimensional arrays corresponding to the + uniform distribution over the n possible actions (same interface as in :class:`~.net.discrete.Actor`). + """ + + def __init__(self, action_space: spaces.Box | spaces.Discrete) -> None: + super().__init__() + self._action_space = action_space + self._space_info = ActionSpaceInfo.from_space(action_space) + + @property + def action_space(self) -> spaces.Box | spaces.Discrete: + return self._action_space + + @property + def space_info(self) -> ActionSpaceInfo: + return self._space_info + + def get_preprocess_net(self) -> nn.Module: + return nn.Identity() + + def get_output_dim(self) -> int: + return self.space_info.action_dim + + def forward( + self, + obs: np.ndarray | torch.Tensor, + state: Any | None = None, + info: dict[str, Any] | None = None, + ) -> tuple[np.ndarray, Any | None]: + batch_size = len(obs) + if isinstance(self.action_space, spaces.Box): + action = np.stack([self.action_space.sample() for _ in range(batch_size)]) + else: + # Discrete Actors currently return an n-dimensional array of probabilities for each action + action = 1 / self.action_space.n * np.ones((batch_size, self.action_space.n)) + return action, state + + def getattr_with_matching_alt_value(obj: Any, attr_name: str, alt_value: T | None) -> T: """Gets the given attribute from the given object or takes the alternative value if it is not present. If both are present, they are required to match. From ad1bdc1c2f81581672d768927e93bd101b7ae9cb Mon Sep 17 00:00:00 2001 From: Michael Panchenko Date: Wed, 14 Aug 2024 13:17:34 +0200 Subject: [PATCH 03/11] Added RandomActionExperimentBuilder --- tianshou/highlevel/experiment.py | 12 ++++++++++-- 1 file changed, 10 insertions(+), 2 deletions(-) diff --git a/tianshou/highlevel/experiment.py b/tianshou/highlevel/experiment.py index 74413997b..e5b9536ed 100644 --- a/tianshou/highlevel/experiment.py +++ b/tianshou/highlevel/experiment.py @@ -19,7 +19,7 @@ import os import pickle -from abc import abstractmethod +from abc import ABC, abstractmethod from collections.abc import Sequence from contextlib import suppress from copy import deepcopy @@ -48,6 +48,7 @@ NPGAgentFactory, PGAgentFactory, PPOAgentFactory, + RandomActionAgentFactory, REDQAgentFactory, SACAgentFactory, TD3AgentFactory, @@ -478,7 +479,7 @@ def run( return launcher.launch(experiments=self.experiments) -class ExperimentBuilder: +class ExperimentBuilder(ABC): """A helper class (following the builder pattern) for creating experiments. It contains a lot of defaults for the setup which can be adjusted using the @@ -676,6 +677,13 @@ def build_seeded_collection(self, num_experiments: int) -> ExperimentCollection: return ExperimentCollection(seeded_experiments) +class RandomActionExperimentBuilder(ExperimentBuilder): + def _create_agent_factory(self) -> RandomActionAgentFactory: + return RandomActionAgentFactory( + sampling_config=self.sampling_config, optim_factory=self._get_optim_factory() + ) + + class _BuilderMixinActorFactory(ActorFutureProviderProtocol): def __init__(self, continuous_actor_type: ContinuousActorType): self._continuous_actor_type = continuous_actor_type From c699e1ddab78c17279ddcff3b2041b185e5bc054 Mon Sep 17 00:00:00 2001 From: Michael Panchenko Date: Wed, 14 Aug 2024 13:35:09 +0200 Subject: [PATCH 04/11] Formatting [ci skip] --- tianshou/highlevel/experiment.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/tianshou/highlevel/experiment.py b/tianshou/highlevel/experiment.py index e5b9536ed..7531616a5 100644 --- a/tianshou/highlevel/experiment.py +++ b/tianshou/highlevel/experiment.py @@ -680,7 +680,8 @@ def build_seeded_collection(self, num_experiments: int) -> ExperimentCollection: class RandomActionExperimentBuilder(ExperimentBuilder): def _create_agent_factory(self) -> RandomActionAgentFactory: return RandomActionAgentFactory( - sampling_config=self.sampling_config, optim_factory=self._get_optim_factory() + sampling_config=self.sampling_config, + optim_factory=self._get_optim_factory(), ) From c8ef74e2db2adaa4de6faa00568d3ecccf9eeffd Mon Sep 17 00:00:00 2001 From: Michael Panchenko Date: Wed, 14 Aug 2024 14:37:20 +0200 Subject: [PATCH 05/11] RandomActionPolicy: fix unpacking of act, state. Formatting --- docs/01_tutorials/04_tictactoe.rst | 44 +++++++++++++++--------------- tianshou/policy/base.py | 5 ++-- 2 files changed, 25 insertions(+), 24 deletions(-) diff --git a/docs/01_tutorials/04_tictactoe.rst b/docs/01_tutorials/04_tictactoe.rst index 4247c539d..fb7f490ce 100644 --- a/docs/01_tutorials/04_tictactoe.rst +++ b/docs/01_tutorials/04_tictactoe.rst @@ -127,28 +127,28 @@ Tianshou already provides some builtin classes for multi-agent learning. You can :: >>> from tianshou.data import Collector - >>> from tianshou.env import DummyVectorEnv - >>> from tianshou.policy import RandomPolicy, MultiAgentPolicyManager - >>> - >>> # agents should be wrapped into one policy, - >>> # which is responsible for calling the acting agent correctly - >>> # here we use two random agents - >>> policy = MultiAgentPolicyManager( - >>> [MARLRandomPolicy(action_space=env.action_space), RandomPolicy(action_space=env.action_space)], env - >>> ) - >>> - >>> # need to vectorize the environment for the collector - >>> env = DummyVectorEnv([lambda: env]) - >>> - >>> # use collectors to collect a episode of trajectories - >>> # the reward is a vector, so we need a scalar metric to monitor the training - >>> collector = Collector(policy, env) - >>> - >>> # you will see a long trajectory showing the board status at each timestep - >>> result = collector.collect(n_episode=1, render=.1) - (only show the last 3 steps) - | | - X | X | - + >>> from tianshou.env import DummyVectorEnv + >>> from tianshou.policy import RandomPolicy, MultiAgentPolicyManager + >>> + >>> # agents should be wrapped into one policy, + >>> # which is responsible for calling the acting agent correctly + >>> # here we use two random agents + >>> policy = MultiAgentPolicyManager( + >>> [MARLRandomPolicy(action_space=env.action_space), RandomPolicy(action_space=env.action_space)], env + >>> ) + >>> + >>> # need to vectorize the environment for the collector + >>> env = DummyVectorEnv([lambda: env]) + >>> + >>> # use collectors to collect a episode of trajectories + >>> # the reward is a vector, so we need a scalar metric to monitor the training + >>> collector = Collector(policy, env) + >>> + >>> # you will see a long trajectory showing the board status at each timestep + >>> result = collector.collect(n_episode=1, render=.1) + (only show the last 3 steps) + | | + X | X | - >>> from tianshou.env import DummyVectorEnv >>> from tianshou.policy import RandomPolicy, MultiAgentPolicyManager diff --git a/tianshou/policy/base.py b/tianshou/policy/base.py index f2c027fa2..fb3c71e0f 100644 --- a/tianshou/policy/base.py +++ b/tianshou/policy/base.py @@ -711,8 +711,9 @@ def forward( batch: ObsBatchProtocol, state: dict | BatchProtocol | np.ndarray | None = None, **kwargs: Any, - ) -> ActBatchProtocol: - return cast(ActBatchProtocol, Batch(act=self.actor(batch.obs))) + ) -> ActStateBatchProtocol: + act, next_state = self.actor(batch.obs, state) + return cast(ActStateBatchProtocol, Batch(act=act, state=next_state)) def learn(self, batch: RolloutBatchProtocol, *args: Any, **kwargs: Any) -> TrainingStats: return TrainingStats() From dcf1b2edf021596e3593a70c4b21b8f89ea437ce Mon Sep 17 00:00:00 2001 From: Michael Panchenko Date: Sat, 17 Aug 2024 11:31:56 +0200 Subject: [PATCH 06/11] RandomActionPolicy: fixes for discrete case, added tests --- test/base/test_policy.py | 44 ++++++++++++++++++++++++++++++++++++ tianshou/policy/base.py | 2 +- tianshou/utils/net/common.py | 16 +++++++++++-- 3 files changed, 59 insertions(+), 3 deletions(-) diff --git a/test/base/test_policy.py b/test/base/test_policy.py index 4d26905c3..8911bab16 100644 --- a/test/base/test_policy.py +++ b/test/base/test_policy.py @@ -4,7 +4,9 @@ import torch from torch.distributions import Categorical, Distribution, Independent, Normal +from tianshou.data import Batch from tianshou.policy import BasePolicy, PPOPolicy +from tianshou.policy.base import RandomActionPolicy from tianshou.utils.net.common import ActorCritic, Net from tianshou.utils.net.continuous import ActorProb, Critic from tianshou.utils.net.discrete import Actor @@ -77,3 +79,45 @@ def test_get_action(self, policy: PPOPolicy) -> None: actions = [policy.compute_action(sample_obs) for _ in range(10)] # check that the actions are the same in deterministic mode assert len(set(map(_to_hashable, actions))) == 1 + + @staticmethod + def test_random_policy_discrete_actions() -> None: + action_space = gym.spaces.Discrete(3) + policy = RandomActionPolicy(action_space=action_space) + + # forward of actor returns discrete probabilities, in compliance with the overall discrete actor + action_probs = policy.actor(np.zeros((10, 2)))[0] + assert np.allclose(action_probs, 1 / 3 * np.ones((10, 3))) + + actions = [] + for _ in range(10): + action = policy.compute_action(np.array([0])) + assert action_space.contains(action) + actions.append(action) + + # not all actions are the same + assert len(set(actions)) > 1 + + # test batched forward + action_batch = policy(Batch(obs=np.zeros((10, 2)))) + assert action_batch.act.shape == (10,) + assert len(set(action_batch.act.tolist())) > 1 + + @staticmethod + def test_random_policy_continuous_actions() -> None: + action_space = gym.spaces.Box(low=-1, high=1, shape=(3,)) + policy = RandomActionPolicy(action_space=action_space) + + actions = [] + for _ in range(10): + action = policy.compute_action(np.array([0])) + assert action_space.contains(action) + actions.append(action) + + # not all actions are the same + assert len(set(map(_to_hashable, actions))) > 1 + + # test batched forward + action_batch = policy(Batch(obs=np.zeros((10, 2)))) + assert action_batch.act.shape == (10, 3) + assert len(set(map(_to_hashable, action_batch.act))) > 1 diff --git a/tianshou/policy/base.py b/tianshou/policy/base.py index fb3c71e0f..d21010bf0 100644 --- a/tianshou/policy/base.py +++ b/tianshou/policy/base.py @@ -712,7 +712,7 @@ def forward( state: dict | BatchProtocol | np.ndarray | None = None, **kwargs: Any, ) -> ActStateBatchProtocol: - act, next_state = self.actor(batch.obs, state) + act, next_state = self.actor.compute_action_batch(batch.obs), state return cast(ActStateBatchProtocol, Batch(act=act, state=next_state)) def learn(self, batch: RolloutBatchProtocol, *args: Any, **kwargs: Any) -> TrainingStats: diff --git a/tianshou/utils/net/common.py b/tianshou/utils/net/common.py index 7510ae7e7..243a04093 100644 --- a/tianshou/utils/net/common.py +++ b/tianshou/utils/net/common.py @@ -7,7 +7,7 @@ from gymnasium import spaces from torch import nn -from tianshou.data.batch import Batch +from tianshou.data.batch import Batch, BatchProtocol from tianshou.data.types import RecurrentStateBatch from tianshou.utils.space_info import ActionSpaceInfo @@ -661,9 +661,13 @@ def get_preprocess_net(self) -> nn.Module: def get_output_dim(self) -> int: return self.space_info.action_dim + @property + def is_discrete(self) -> bool: + return isinstance(self.action_space, spaces.Discrete) + def forward( self, - obs: np.ndarray | torch.Tensor, + obs: np.ndarray | torch.Tensor | BatchProtocol, state: Any | None = None, info: dict[str, Any] | None = None, ) -> tuple[np.ndarray, Any | None]: @@ -675,6 +679,14 @@ def forward( action = 1 / self.action_space.n * np.ones((batch_size, self.action_space.n)) return action, state + def compute_action_batch(self, obs: np.ndarray | torch.Tensor | BatchProtocol) -> np.ndarray: + if self.is_discrete: + # Different from forward which returns discrete probabilities, see comment there + assert isinstance(self.action_space, spaces.Discrete) # for mypy + return np.random.randint(low=0, high=self.action_space.n, size=len(obs)) + else: + return self.forward(obs)[0] + def getattr_with_matching_alt_value(obj: Any, attr_name: str, alt_value: T | None) -> T: """Gets the given attribute from the given object or takes the alternative value if it is not present. From 1eaf276dae5d188233e22b35e79535c356fcd2f0 Mon Sep 17 00:00:00 2001 From: Michael Panchenko Date: Sun, 18 Aug 2024 17:38:09 +0200 Subject: [PATCH 07/11] Docs, copy-paste error [ci skip] --- docs/01_tutorials/04_tictactoe.rst | 23 ----------------------- 1 file changed, 23 deletions(-) diff --git a/docs/01_tutorials/04_tictactoe.rst b/docs/01_tutorials/04_tictactoe.rst index fb7f490ce..60387d30a 100644 --- a/docs/01_tutorials/04_tictactoe.rst +++ b/docs/01_tutorials/04_tictactoe.rst @@ -127,29 +127,6 @@ Tianshou already provides some builtin classes for multi-agent learning. You can :: >>> from tianshou.data import Collector - >>> from tianshou.env import DummyVectorEnv - >>> from tianshou.policy import RandomPolicy, MultiAgentPolicyManager - >>> - >>> # agents should be wrapped into one policy, - >>> # which is responsible for calling the acting agent correctly - >>> # here we use two random agents - >>> policy = MultiAgentPolicyManager( - >>> [MARLRandomPolicy(action_space=env.action_space), RandomPolicy(action_space=env.action_space)], env - >>> ) - >>> - >>> # need to vectorize the environment for the collector - >>> env = DummyVectorEnv([lambda: env]) - >>> - >>> # use collectors to collect a episode of trajectories - >>> # the reward is a vector, so we need a scalar metric to monitor the training - >>> collector = Collector(policy, env) - >>> - >>> # you will see a long trajectory showing the board status at each timestep - >>> result = collector.collect(n_episode=1, render=.1) - (only show the last 3 steps) - | | - X | X | - - >>> from tianshou.env import DummyVectorEnv >>> from tianshou.policy import RandomPolicy, MultiAgentPolicyManager >>> From 6f8648ab141e05170be3aae9545997493d03e8b0 Mon Sep 17 00:00:00 2001 From: Michael Panchenko Date: Sat, 24 Aug 2024 14:03:40 +0200 Subject: [PATCH 08/11] Batch: added possibility to change shape to atleast_2d, including distributions --- test/base/test_batch.py | 60 ++++++++++++++++++++++++++++++++++++++++- tianshou/data/batch.py | 39 ++++++++++++++++++++++++++- 2 files changed, 97 insertions(+), 2 deletions(-) diff --git a/test/base/test_batch.py b/test/base/test_batch.py index bb43cd682..5fa40758f 100644 --- a/test/base/test_batch.py +++ b/test/base/test_batch.py @@ -9,10 +9,11 @@ import pytest import torch from deepdiff import DeepDiff +from torch.distributions import Distribution, Independent, Normal from torch.distributions.categorical import Categorical from tianshou.data import Batch, to_numpy, to_torch -from tianshou.data.batch import IndexType, get_sliced_dist +from tianshou.data.batch import IndexType, dist_to_atleast_2d, get_sliced_dist def test_batch() -> None: @@ -766,6 +767,63 @@ def test_batch_over_batch_to_torch() -> None: assert batch.b.d.dtype == torch.float32 assert batch.b.e.dtype == torch.float32 + @staticmethod + @pytest.mark.parametrize( + "dist, expected_batch_shape", + [ + (Categorical(probs=torch.tensor([0.3, 0.7])), (1,)), + (Categorical(probs=torch.tensor([[0.3, 0.7], [0.4, 0.6]])), (2,)), + (Normal(loc=torch.tensor(0.0), scale=torch.tensor(1.0)), (1,)), + (Normal(loc=torch.tensor([0.0, 1.0]), scale=torch.tensor([1.0, 2.0])), (2,)), + (Independent(Normal(loc=torch.tensor(0.0), scale=torch.tensor(1.0)), 0), (1,)), + ( + Independent( + Normal(loc=torch.tensor([0.0, 1.0]), scale=torch.tensor([1.0, 2.0])), + 0, + ), + (2,), + ), + ], + ) + def test_dist_to_atleast_2d(dist: Distribution, expected_batch_shape: tuple[int]) -> None: + result = dist_to_atleast_2d(dist) + assert result.batch_shape == expected_batch_shape + + # Additionally check that the parameters are correctly transformed + if isinstance(dist, Categorical): + assert isinstance(result, Categorical) + assert result.probs.shape[:-1] == expected_batch_shape + elif isinstance(dist, Normal): + assert isinstance(result, Normal) + assert result.loc.shape == expected_batch_shape + assert result.scale.shape == expected_batch_shape + elif isinstance(dist, Independent): + assert isinstance(result, Independent) + assert result.base_dist.batch_shape == expected_batch_shape + + @staticmethod + @pytest.mark.parametrize( + "dist", + [ + Categorical(probs=torch.tensor([0.3, 0.7])), + Normal(loc=torch.tensor(0.0), scale=torch.tensor(1.0)), + Independent(Normal(loc=torch.tensor(0.0), scale=torch.tensor(1.0)), 0), + ], + ) + def test_dist_to_atleast_2d_idempotent(dist: Distribution) -> None: + result1 = dist_to_atleast_2d(dist) + result2 = dist_to_atleast_2d(result1) + assert result1 == result2 + + @staticmethod + def test_batch_to_atleast_2d() -> None: + scalar_batch = Batch(a=1, b=2, dist=Categorical(probs=torch.ones(3))) + assert scalar_batch.dist.batch_shape == () + assert scalar_batch.a.shape == scalar_batch.b.shape == () + scalar_batch_2d = scalar_batch.to_at_least_2d() + assert scalar_batch_2d.dist.batch_shape == (1,) + assert scalar_batch_2d.a.shape == scalar_batch_2d.b.shape == (1, 1) + class TestAssignment: @staticmethod diff --git a/tianshou/data/batch.py b/tianshou/data/batch.py index aeae5f3ff..70478a87d 100644 --- a/tianshou/data/batch.py +++ b/tianshou/data/batch.py @@ -285,6 +285,23 @@ def get_len_of_dist(dist: Distribution) -> int: return dist.batch_shape[0] +def dist_to_atleast_2d(dist: TDistribution) -> TDistribution: + """Convert a distribution to at least 2D, such that the `batch_shape` attribute has a len of at least 1.""" + if len(dist.batch_shape) > 0: + return dist + if isinstance(dist, Categorical): + return Categorical(probs=dist.probs.unsqueeze(0)) # type: ignore[return-value] + elif isinstance(dist, Normal): + return Normal(loc=dist.loc.unsqueeze(0), scale=dist.scale.unsqueeze(0)) # type: ignore[return-value] + elif isinstance(dist, Independent): + return Independent( + dist_to_atleast_2d(dist.base_dist), + dist.reinterpreted_batch_ndims, + ) # type: ignore[return-value] + else: + raise NotImplementedError(f"Unsupported distribution for conversion to 2D: {type(dist)}") + + # Note: This is implemented as a protocol because the interface # of Batch is always extended by adding new fields. Having a hierarchy of # protocols building off this one allows for type safety and IDE support despite @@ -602,6 +619,14 @@ def get(self, key: str, default: Any | None = None) -> Any: def pop(self, key: str, default: Any | None = None) -> Any: raise ProtocolCalledException + def to_at_least_2d(self) -> Self: + """Ensures that all arrays and dists in the batch have at least 2 dimensions. + + This is useful for ensuring that all arrays in the batch can be concatenated + along a new axis. + """ + raise ProtocolCalledException + class Batch(BatchProtocol): """See :class:`~tianshou.data.batch.BatchProtocol`.""" @@ -1160,7 +1185,7 @@ def __len__(self) -> int: if isinstance(obj, Distribution): lens.append(get_len_of_dist(obj)) continue - raise TypeError(f"Entry for {key} in {self} is {obj}has no len()") + raise TypeError(f"Entry for {key} in {self} is {obj} has no len()") if not lens: return 0 return min(lens) @@ -1326,6 +1351,18 @@ def replace_empty_batches_by_none(self) -> None: else: val.replace_empty_batches_by_none() + def to_at_least_2d(self) -> Self: + """Ensures that all arrays and dists in the batch have at least 2 dimensions. + + This is useful for ensuring that all arrays in the batch can be concatenated + along a new axis. + """ + result = self.apply_values_transform(np.atleast_2d, inplace=False) + for key, val in self.items(): + if isinstance(val, Distribution): + result[key] = dist_to_atleast_2d(val) + return result + def _apply_batch_values_func_recursively( batch: TBatch, From b0ba423184d3a9bd6673ead3051c6584fbc0585a Mon Sep 17 00:00:00 2001 From: Michael Panchenko Date: Sat, 24 Aug 2024 14:04:36 +0200 Subject: [PATCH 09/11] CollectStats: better collection for std of actions (not flattening). Added tests, renamed entries --- test/base/test_stats.py | 42 ++++++++++++++++++++++++++++ tianshou/data/collector.py | 56 ++++++++++++++++++++++++++++++-------- tianshou/data/stats.py | 25 +++++++++++++++++ 3 files changed, 111 insertions(+), 12 deletions(-) diff --git a/test/base/test_stats.py b/test/base/test_stats.py index 9776374ba..821152e83 100644 --- a/test/base/test_stats.py +++ b/test/base/test_stats.py @@ -1,5 +1,12 @@ +from typing import cast + +import numpy as np import pytest +import torch +from torch.distributions import Categorical, Normal +from tianshou.data import Batch, CollectStats +from tianshou.data.collector import CollectStepBatchProtocol, get_stddev_from_dist from tianshou.policy.base import TrainingStats, TrainingStatsWrapper @@ -47,3 +54,38 @@ def test_training_stats_wrapper() -> None: "loss_field", ), "Attribute `loss_field` not found in `wrapped_train_stats`." assert wrapped_train_stats.wrapped_stats.loss_field == wrapped_train_stats.loss_field == 13 + + @staticmethod + @pytest.mark.parametrize( + "act,dist", + ( + (np.array(1), Categorical(probs=torch.tensor([0.5, 0.5]))), + (np.array([1, 2, 3]), Normal(torch.zeros(3), torch.ones(3))), + ), + ) + def test_collect_stats_update_at_step( + act: np.ndarray, + dist: torch.distributions.Distribution, + ) -> None: + step_batch = cast( + CollectStepBatchProtocol, + Batch( + info={}, + obs=np.array([1, 2, 3]), + obs_next=np.array([4, 5, 6]), + act=act, + rew=np.array(1.0), + done=np.array(False), + terminated=np.array(False), + dist=dist, + ).to_at_least_2d(), + ) + stats = CollectStats() + for _ in range(10): + stats.update_at_step_batch(step_batch) + stats.refresh_all_sequence_stats() + assert stats.n_collected_steps == 10 + assert stats.pred_dist_std_array is not None + assert np.allclose(stats.pred_dist_std_array, get_stddev_from_dist(dist)) + assert stats.pred_dist_std_array_stat is not None + assert stats.pred_dist_std_array_stat[0].mean == get_stddev_from_dist(dist)[0].item() diff --git a/tianshou/data/collector.py b/tianshou/data/collector.py index 2c615923d..125fe0e90 100644 --- a/tianshou/data/collector.py +++ b/tianshou/data/collector.py @@ -11,7 +11,7 @@ import numpy as np import torch from overrides import override -from torch.distributions import Distribution +from torch.distributions import Categorical, Distribution from tianshou.data import ( Batch, @@ -23,6 +23,7 @@ to_numpy, ) from tianshou.data.buffer.base import MalformedBufferError +from tianshou.data.stats import compute_dim_to_summary_stats from tianshou.data.types import ( ActBatchProtocol, DistBatchProtocol, @@ -75,6 +76,30 @@ class EpisodeBatchProtocol(RolloutBatchProtocol): """ +def get_stddev_from_dist(dist: Distribution) -> torch.Tensor: + """Return the standard deviation of the given distribution. + + Same as `dist.stddev` for all distributions except `Categorical`, where it is computed + by assuming that the output values 0, ..., K have the corresponding numerical meaning. + See `here `_ + for a discussion on `stddev` and `mean` of `Categorical`. + """ + if isinstance(dist, Categorical): + # torch doesn't implement stddev for Categorical, so we compute it ourselves + probs = torch.atleast_2d(dist.probs) + n_actions = probs.shape[-1] + possible_actions = torch.arange(n_actions, device=dist.probs.device).float() + + mean = torch.sum(probs * possible_actions, dim=1) + var = torch.sum(probs * (possible_actions - mean.unsqueeze(1)) ** 2, dim=1) + stddev = torch.sqrt(var) + if len(dist.batch_shape) == 0: + return stddev + return torch.atleast_2d(stddev).T + + return dist.stddev if dist is not None else torch.tensor([]) + + @dataclass(kw_only=True) class CollectStatsBase(DataclassPPrintMixin): """The most basic stats, often used for offline learning.""" @@ -115,10 +140,10 @@ class CollectStats(CollectStatsBase): """The collected episode lengths.""" lens_stat: SequenceSummaryStats | None = None """Stats of the collected episode lengths.""" - std_array: np.ndarray | None = None + pred_dist_std_array: np.ndarray | None = None """The standard deviations of the predicted distributions.""" - std_array_stat: SequenceSummaryStats | None = None - """Stats of the standard deviations of the predicted distributions.""" + pred_dist_std_array_stat: dict[int, SequenceSummaryStats] | None = None + """Stats of the standard deviations of the predicted distributions (maps action dim to stats)""" @classmethod def with_autogenerated_stats( @@ -150,12 +175,18 @@ def update_at_step_batch( refresh_sequence_stats: bool = False, ) -> None: self.n_collected_steps += len(step_batch) - action_std = step_batch.dist.stddev if step_batch.dist is not None else None - if action_std is not None: - if self.std_array is None: - self.std_array = to_numpy(action_std) + dist = step_batch.dist + action_std: torch.Tensor | None = None + + if dist is not None: + action_std = np.atleast_2d(to_numpy(get_stddev_from_dist(dist))) + + if self.pred_dist_std_array is None: + self.pred_dist_std_array = np.atleast_2d(to_numpy(action_std)) else: - self.std_array = np.concatenate((self.std_array, to_numpy(action_std))) + self.pred_dist_std_array = np.concatenate( + (self.pred_dist_std_array, np.atleast_2d(to_numpy(action_std))), + ) if refresh_sequence_stats: self.refresh_std_array_stats() @@ -208,10 +239,11 @@ def refresh_len_stats(self) -> None: self.lens_stat = None def refresh_std_array_stats(self) -> None: - if self.std_array is not None and self.std_array.size > 0: - self.std_array_stat = SequenceSummaryStats.from_sequence(self.std_array) + if self.pred_dist_std_array is not None and self.pred_dist_std_array.size > 0: + # need to use .T because action dim supposed to be the first axis in compute_dim_to_summary_stats + self.pred_dist_std_array_stat = compute_dim_to_summary_stats(self.pred_dist_std_array.T) else: - self.std_array_stat = None + self.pred_dist_std_array_stat = None def refresh_all_sequence_stats(self) -> None: self.refresh_return_stats() diff --git a/tianshou/data/stats.py b/tianshou/data/stats.py index ed64a429d..11d64c017 100644 --- a/tianshou/data/stats.py +++ b/tianshou/data/stats.py @@ -1,3 +1,4 @@ +import logging from collections.abc import Sequence from dataclasses import dataclass from typing import TYPE_CHECKING, Optional @@ -10,6 +11,8 @@ from tianshou.data import CollectStats, CollectStatsBase from tianshou.policy.base import TrainingStats +log = logging.getLogger(__name__) + @dataclass(kw_only=True) class SequenceSummaryStats(DataclassPPrintMixin): @@ -24,6 +27,14 @@ class SequenceSummaryStats(DataclassPPrintMixin): def from_sequence(cls, sequence: Sequence[float | int] | np.ndarray) -> "SequenceSummaryStats": if len(sequence) == 0: return cls(mean=0.0, std=0.0, max=0.0, min=0.0) + + if hasattr(sequence, "shape") and len(sequence.shape) > 1: + log.warning( + f"Sequence has shape {sequence.shape}, but only 1D sequences are supported. " + "Stats will be computed from the flattened sequence. For computing stats " + "for each dimension consider using the function `compute_dim_to_summary_stats`.", + ) + return cls( mean=float(np.mean(sequence)), std=float(np.std(sequence)), @@ -32,6 +43,20 @@ def from_sequence(cls, sequence: Sequence[float | int] | np.ndarray) -> "Sequenc ) +def compute_dim_to_summary_stats( + arr: Sequence[Sequence[float]] | np.ndarray, +) -> dict[int, SequenceSummaryStats]: + """Compute summary statistics for each dimension of a sequence. + + :param arr: a 2-dim arr (or sequence of sequences) from which to compute the statistics. + :return: A dictionary of summary statistics for each dimension. + """ + stats = {} + for dim, seq in enumerate(arr): + stats[dim] = SequenceSummaryStats.from_sequence(seq) + return stats + + @dataclass(kw_only=True) class TimingStats(DataclassPPrintMixin): """A data structure for storing timing statistics.""" From b3d78d2f6c1182a70cdf8a7699c1ea403a90db08 Mon Sep 17 00:00:00 2001 From: Michael Panchenko Date: Mon, 26 Aug 2024 18:28:09 +0200 Subject: [PATCH 10/11] Minor fix in tensorboard logger (casting key to str) --- tianshou/utils/logger/tensorboard.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tianshou/utils/logger/tensorboard.py b/tianshou/utils/logger/tensorboard.py index 1406cfbb8..ef504cb58 100644 --- a/tianshou/utils/logger/tensorboard.py +++ b/tianshou/utils/logger/tensorboard.py @@ -71,7 +71,7 @@ def add_to_result( if exclude_arrays and isinstance(value, np.ndarray): continue - new_key = prefix + delimiter + key + new_key = prefix + delimiter + str(key) new_key = new_key.lstrip(delimiter) if isinstance(value, dict): From 4e031917046fea4ed78628612afdaf396fdba872 Mon Sep 17 00:00:00 2001 From: Michael Panchenko Date: Mon, 26 Aug 2024 18:49:47 +0200 Subject: [PATCH 11/11] Spelling --- docs/spelling_wordlist.txt | 2 ++ 1 file changed, 2 insertions(+) diff --git a/docs/spelling_wordlist.txt b/docs/spelling_wordlist.txt index 25eaa526c..b49094d03 100644 --- a/docs/spelling_wordlist.txt +++ b/docs/spelling_wordlist.txt @@ -288,3 +288,5 @@ monte carlo subclass subclassing +dist +dists