diff --git a/test/base/test_policy.py b/test/base/test_policy.py
index 4d26905c3..8911bab16 100644
--- a/test/base/test_policy.py
+++ b/test/base/test_policy.py
@@ -4,7 +4,9 @@
 import torch
 from torch.distributions import Categorical, Distribution, Independent, Normal
 
+from tianshou.data import Batch
 from tianshou.policy import BasePolicy, PPOPolicy
+from tianshou.policy.base import RandomActionPolicy
 from tianshou.utils.net.common import ActorCritic, Net
 from tianshou.utils.net.continuous import ActorProb, Critic
 from tianshou.utils.net.discrete import Actor
@@ -77,3 +79,45 @@ def test_get_action(self, policy: PPOPolicy) -> None:
         actions = [policy.compute_action(sample_obs) for _ in range(10)]
         # check that the actions are the same in deterministic mode
         assert len(set(map(_to_hashable, actions))) == 1
+
+    @staticmethod
+    def test_random_policy_discrete_actions() -> None:
+        action_space = gym.spaces.Discrete(3)
+        policy = RandomActionPolicy(action_space=action_space)
+
+        # forward of actor returns discrete probabilities, in compliance with the overall discrete actor
+        action_probs = policy.actor(np.zeros((10, 2)))[0]
+        assert np.allclose(action_probs, 1 / 3 * np.ones((10, 3)))
+
+        actions = []
+        for _ in range(10):
+            action = policy.compute_action(np.array([0]))
+            assert action_space.contains(action)
+            actions.append(action)
+
+        # not all actions are the same
+        assert len(set(actions)) > 1
+
+        # test batched forward
+        action_batch = policy(Batch(obs=np.zeros((10, 2))))
+        assert action_batch.act.shape == (10,)
+        assert len(set(action_batch.act.tolist())) > 1
+
+    @staticmethod
+    def test_random_policy_continuous_actions() -> None:
+        action_space = gym.spaces.Box(low=-1, high=1, shape=(3,))
+        policy = RandomActionPolicy(action_space=action_space)
+
+        actions = []
+        for _ in range(10):
+            action = policy.compute_action(np.array([0]))
+            assert action_space.contains(action)
+            actions.append(action)
+
+        # not all actions are the same
+        assert len(set(map(_to_hashable, actions))) > 1
+
+        # test batched forward
+        action_batch = policy(Batch(obs=np.zeros((10, 2))))
+        assert action_batch.act.shape == (10, 3)
+        assert len(set(map(_to_hashable, action_batch.act))) > 1
diff --git a/tianshou/policy/base.py b/tianshou/policy/base.py
index fb3c71e0f..d21010bf0 100644
--- a/tianshou/policy/base.py
+++ b/tianshou/policy/base.py
@@ -712,7 +712,7 @@ def forward(
         state: dict | BatchProtocol | np.ndarray | None = None,
         **kwargs: Any,
     ) -> ActStateBatchProtocol:
-        act, next_state = self.actor(batch.obs, state)
+        act, next_state = self.actor.compute_action_batch(batch.obs), state
         return cast(ActStateBatchProtocol, Batch(act=act, state=next_state))
 
     def learn(self, batch: RolloutBatchProtocol, *args: Any, **kwargs: Any) -> TrainingStats:
diff --git a/tianshou/utils/net/common.py b/tianshou/utils/net/common.py
index 7510ae7e7..243a04093 100644
--- a/tianshou/utils/net/common.py
+++ b/tianshou/utils/net/common.py
@@ -7,7 +7,7 @@
 from gymnasium import spaces
 from torch import nn
 
-from tianshou.data.batch import Batch
+from tianshou.data.batch import Batch, BatchProtocol
 from tianshou.data.types import RecurrentStateBatch
 from tianshou.utils.space_info import ActionSpaceInfo
 
@@ -661,9 +661,13 @@ def get_preprocess_net(self) -> nn.Module:
     def get_output_dim(self) -> int:
         return self.space_info.action_dim
 
+    @property
+    def is_discrete(self) -> bool:
+        return isinstance(self.action_space, spaces.Discrete)
+
     def forward(
         self,
-        obs: np.ndarray | torch.Tensor,
+        obs: np.ndarray | torch.Tensor | BatchProtocol,
         state: Any | None = None,
         info: dict[str, Any] | None = None,
     ) -> tuple[np.ndarray, Any | None]:
@@ -675,6 +679,14 @@ def forward(
         action = 1 / self.action_space.n * np.ones((batch_size, self.action_space.n))
         return action, state
 
+    def compute_action_batch(self, obs: np.ndarray | torch.Tensor | BatchProtocol) -> np.ndarray:
+        if self.is_discrete:
+            # Different from forward which returns discrete probabilities, see comment there
+            assert isinstance(self.action_space, spaces.Discrete)  # for mypy
+            return np.random.randint(low=0, high=self.action_space.n, size=len(obs))
+        else:
+            return self.forward(obs)[0]
+
 
 def getattr_with_matching_alt_value(obj: Any, attr_name: str, alt_value: T | None) -> T:
     """Gets the given attribute from the given object or takes the alternative value if it is not present.
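
Note (not part of the patch): a minimal usage sketch for the `RandomActionPolicy` behavior introduced above, based only on the calls exercised in the new tests; the CartPole environment and all variable names are illustrative assumptions, not part of the change.

```python
import gymnasium as gym
import numpy as np

from tianshou.data import Batch
from tianshou.policy.base import RandomActionPolicy

# discrete action space: actions are sampled uniformly from {0, ..., n-1}
env = gym.make("CartPole-v1")
policy = RandomActionPolicy(action_space=env.action_space)

# single observation -> one random action contained in the space
obs, _ = env.reset()
action = policy.compute_action(obs)
assert env.action_space.contains(action)

# batched forward: the policy now delegates to the actor's
# compute_action_batch, which returns sampled action indices directly,
# rather than the uniform probabilities that the actor's forward produces
act_batch = policy(Batch(obs=np.zeros((10, 4))))
assert act_batch.act.shape == (10,)
```

This illustrates the rationale for the `forward` change in `tianshou/policy/base.py`: for discrete spaces the actor's `forward` returns per-action probabilities (to comply with the discrete actor interface), so the policy calls the new `compute_action_batch` to obtain concrete sampled actions instead.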