From 198669224285a60a6431f537067bad5e10b52702 Mon Sep 17 00:00:00 2001 From: Mark Towers Date: Tue, 21 May 2024 10:01:03 +0100 Subject: [PATCH] Update Gymnasium to 1.0.0a1 and remove Atari in favour of ale-py new gymnasium implementation (#117) --- .github/dependabot.yml | 14 + .github/workflows/build-publish.yml | 19 +- .github/workflows/optional-test-atari.yml | 19 - .../workflows/optional-test-meltingpot.yml | 2 +- bin/atari.Dockerfile | 37 -- setup.py | 15 +- shimmy/__init__.py | 6 +- shimmy/atari_env.py | 429 ------------------ shimmy/dm_control_multiagent_compatibility.py | 1 + shimmy/meltingpot_compatibility.py | 2 - shimmy/openai_gym_compatibility.py | 1 - shimmy/registration.py | 126 +---- shimmy/utils/envs_configs.py | 171 ------- tests/test_atari.py | 110 ----- tests/test_bsuite.py | 6 +- tests/test_dm_control.py | 6 +- tests/test_dm_lab.py | 7 +- tests/test_gym.py | 5 +- tests/test_meltingpot.py | 1 - 19 files changed, 50 insertions(+), 927 deletions(-) create mode 100644 .github/dependabot.yml delete mode 100644 .github/workflows/optional-test-atari.yml delete mode 100644 bin/atari.Dockerfile delete mode 100644 shimmy/atari_env.py delete mode 100644 tests/test_atari.py diff --git a/.github/dependabot.yml b/.github/dependabot.yml new file mode 100644 index 00000000..ea7c7d78 --- /dev/null +++ b/.github/dependabot.yml @@ -0,0 +1,14 @@ +# To get started with Dependabot version updates, you'll need to specify which +# package ecosystems to update and where the package manifests are located. +# Please see the documentation for all configuration options: +# https://docs.github.com/github/administering-a-repository/configuration-options-for-dependency-updates + +version: 2 +updates: + # Enable version updates for GitHub Actions + - package-ecosystem: "github-actions" + # Look for GitHub Actions workflows in the `root` directory + directory: "/" + # Check the for updates once a week + schedule: + interval: "weekly" diff --git a/.github/workflows/build-publish.yml b/.github/workflows/build-publish.yml index 737c653e..79673d15 100644 --- a/.github/workflows/build-publish.yml +++ b/.github/workflows/build-publish.yml @@ -16,29 +16,14 @@ on: jobs: build-wheels: - runs-on: ${{ matrix.os }} - strategy: - matrix: - include: - - os: ubuntu-latest - python: 37 - platform: manylinux_x86_64 - - os: ubuntu-latest - python: 38 - platform: manylinux_x86_64 - - os: ubuntu-latest - python: 39 - platform: manylinux_x86_64 - - os: ubuntu-latest - python: 310 - platform: manylinux_x86_64 + runs-on: ubuntu-latest steps: - uses: actions/checkout@v3 - name: Set up Python uses: actions/setup-python@v2 with: - python-version: '3.x' + python-version: '3.7' - name: Install dependencies run: python -m pip install --upgrade setuptools wheel build - name: Build wheels diff --git a/.github/workflows/optional-test-atari.yml b/.github/workflows/optional-test-atari.yml deleted file mode 100644 index e11c2d9f..00000000 --- a/.github/workflows/optional-test-atari.yml +++ /dev/null @@ -1,19 +0,0 @@ -name: build -on: [pull_request, push] - -permissions: - contents: read # to fetch code (actions/checkout) - -jobs: - optional-test-atari: - runs-on: ubuntu-latest - steps: - - uses: actions/checkout@v2 - - # Atari - - run: | - docker build -f bin/atari.Dockerfile \ - --build-arg PYTHON_VERSION='3.10' \ - --tag shimmy-atari-docker . - - name: Run atari tests - run: docker run shimmy-atari-docker pytest tests/test_atari.py diff --git a/.github/workflows/optional-test-meltingpot.yml b/.github/workflows/optional-test-meltingpot.yml index 23107e63..79f7b673 100644 --- a/.github/workflows/optional-test-meltingpot.yml +++ b/.github/workflows/optional-test-meltingpot.yml @@ -13,7 +13,7 @@ jobs: # Melting Pot - run: | docker build -f bin/meltingpot.Dockerfile \ - --build-arg PYTHON_VERSION='3.10' \ + --build-arg PYTHON_VERSION='3.11' \ --tag shimmy-meltingpot-docker . - name: Run meltingpot tests run: docker run shimmy-meltingpot-docker pytest tests/test_meltingpot.py diff --git a/bin/atari.Dockerfile b/bin/atari.Dockerfile deleted file mode 100644 index 06827560..00000000 --- a/bin/atari.Dockerfile +++ /dev/null @@ -1,37 +0,0 @@ -# A Dockerfile that sets up a full shimmy install with test dependencies - -# if PYTHON_VERSION is not specified as a build argument, set it to 3.10. -ARG PYTHON_VERSION -ARG PYTHON_VERSION=${PYTHON_VERSION:-3.10} -FROM python:$PYTHON_VERSION - -SHELL ["/bin/bash", "-o", "pipefail", "-c"] - -RUN pip install --upgrade pip - -# Install Shimmy requirements -RUN apt-get -y update \ - && apt-get install --no-install-recommends -y \ - unzip \ - libglu1-mesa-dev \ - libgl1-mesa-dev \ - libosmesa6-dev \ - xvfb \ - patchelf \ - ffmpeg cmake \ - && apt-get autoremove -y \ - && apt-get clean \ - && rm -rf /var/lib/apt/lists/* - -COPY . /usr/local/shimmy/ -WORKDIR /usr/local/shimmy/ - -# Install Shimmy -RUN if [ -f "pyproject.toml" ]; then \ - pip install ".[atari, testing]" --no-cache-dir; \ - else \ - pip install -U "shimmy[atari, testing] @ git+https://github.com/Farama-Foundation/Shimmy.git" --no-cache-dir; \ - mkdir -p bin && mv docker_entrypoint bin/docker_entrypoint; \ - fi - -ENTRYPOINT ["/usr/local/shimmy/bin/docker_entrypoint"] diff --git a/setup.py b/setup.py index 7bde0df9..72da53cd 100644 --- a/setup.py +++ b/setup.py @@ -35,7 +35,6 @@ def get_version(): extras = { "gym-v21": ["gym>=0.21.0,<0.26", "pyglet==1.5.11"], "gym-v26": ["gym>=0.26.2"], - "atari": ["ale-py~=0.8.1"], # "imageio" should be "gymnasium[mujoco]>=0.26" but there are install conflicts "dm-control": ["dm-control>=1.0.10", "imageio", "h5py>=3.7.0"], "dm-control-multi-agent": [ @@ -46,16 +45,18 @@ def get_version(): ], "dm-lab": ["dm-env>=1.6"], "openspiel": ["open_spiel>=1.2", "pettingzoo>=1.23"], - "meltingpot": ["pettingzoo>=1.23", "dm-meltingpot>=2.2.0; python_version > '3.9'"], + "meltingpot": [ + "pettingzoo>=1.23", + "dm-meltingpot>=2.2.2; python_version > '3.10'", + ], "bsuite": ["bsuite>=0.3.5"], } extras["all"] = [ lib for key, libs in extras.items() if key != "gym-v21" for lib in libs ] extras["testing"] = [ - "pytest==7.1.3", + "pytest>=7.1.3", "pillow>=9.3.0", - "autorom[accept-rom-license]~=0.6.0", ] setup( @@ -71,7 +72,7 @@ def get_version(): keywords=["Reinforcement Learning", "game", "RL", "AI"], python_requires=">=3.8", packages=find_packages(), - install_requires=["numpy>=1.18.0", "gymnasium>=0.27.0"], + install_requires=["numpy>=1.18.0", "gymnasium>=1.0.0a1"], tests_require=extras["testing"], extras_require=extras, classifiers=[ @@ -79,11 +80,9 @@ def get_version(): "Programming Language :: Python :: 3.9", "Programming Language :: Python :: 3.10", "Programming Language :: Python :: 3.11", + "Programming Language :: Python :: 3.12", "License :: OSI Approved :: MIT License", "Operating System :: OS Independent", ], include_package_data=True, - entry_points={ - "gymnasium.envs": ["__root__ = shimmy.registration:register_gymnasium_envs"] - }, ) diff --git a/shimmy/__init__.py b/shimmy/__init__.py index b00d4486..d7a0fed8 100644 --- a/shimmy/__init__.py +++ b/shimmy/__init__.py @@ -5,6 +5,10 @@ from shimmy.dm_lab_compatibility import DmLabCompatibilityV0 from shimmy.openai_gym_compatibility import GymV21CompatibilityV0, GymV26CompatibilityV0 +from shimmy.registration import register_gymnasium_envs + +# this registers the environments on `import shimmy` +register_gymnasium_envs() class NotInstallClass: @@ -73,7 +77,7 @@ def __call__(self, *args: list[Any], **kwargs: Any): ] -__version__ = "1.3.0" +__version__ = "2.0.0" try: diff --git a/shimmy/atari_env.py b/shimmy/atari_env.py deleted file mode 100644 index a7ddef51..00000000 --- a/shimmy/atari_env.py +++ /dev/null @@ -1,429 +0,0 @@ -"""ALE-py interface for atari. - -This file was originally copied from https://github.com/mgbellemare/Arcade-Learning-Environment/blob/master/src/python/env/gym.py -Under the GNU General Public License v2.0 - -Copyright is held by the authors - -Changes -* Added `self.render_mode` which is identical to `self._render_mode` -""" -from __future__ import annotations - -import sys -from typing import Any, Sequence - -import ale_py -import ale_py.roms as roms -import ale_py.roms.utils as rom_utils -import gymnasium -import gymnasium.logger as logger -import numpy as np -from gymnasium.error import Error -from gymnasium.spaces import Box, Discrete -from gymnasium.utils.ezpickle import EzPickle - -if sys.version_info < (3, 11): - from typing_extensions import NotRequired, TypedDict -else: - from typing import NotRequired, TypedDict - - -class AtariEnvStepMetadata(TypedDict): - """Atari Environment Step Metadata.""" - - lives: int - episode_frame_number: int - frame_number: int - seeds: NotRequired[Sequence[int]] - - -class AtariEnv(gymnasium.Env[np.ndarray, np.int64], EzPickle): - """(A)rcade (L)earning (Gymnasium) (Env)ironment. - - A Gymnasium wrapper around the Arcade Learning Environment (ALE). - """ - - # No render modes - metadata = { - "render_modes": ["human", "rgb_array"], - "obs_types": {"rgb", "grayscale", "ram"}, - } - - def __init__( - self, - game: str = "pong", - mode: int | None = None, - difficulty: int | None = None, - obs_type: str = "rgb", - frameskip: tuple[int, int] | int = 4, - repeat_action_probability: float = 0.25, - full_action_space: bool = False, - max_num_frames_per_episode: int | None = None, - render_mode: str | None = None, - ): - """Initialize the ALE interface for Gymnasium. - - Default parameters are taken from Machado et al., 2018. - - Args: - game: str => Game to initialize env with. - mode: Optional[int] => Game mode, see Machado et al., 2018 - difficulty: Optional[int] => Game difficulty,see Machado et al., 2018 - obs_type: str => Observation type in { 'rgb', 'grayscale', 'ram' } - frameskip: Union[Tuple[int, int], int] => - Stochastic frameskip as tuple or fixed. - repeat_action_probability: int => - Probability to repeat actions, see Machado et al., 2018 - full_action_space: bool => Use full action space? - max_num_frames_per_episode: int => Max number of frame per episode. - Once `max_num_frames_per_episode` is reached the episode is - truncated. - render_mode: str => One of { 'human', 'rgb_array' }. - If `human` we'll interactively display the screen and enable - game sounds. This will lock emulation to the ROMs specified FPS - If `rgb_array` we'll return the `rgb` key in step metadata with - the current environment RGB frame. - - Note: - - The game must be installed, see ale-import-roms, or ale-py-roms. - - Frameskip values of (low, high) will enable stochastic frame skip - which will sample a random frameskip uniformly each action. - - It is recommended to enable full action space. - See Machado et al., 2018 for more details. - - References: - `Revisiting the Arcade Learning Environment: Evaluation Protocols - and Open Problems for General Agents`, Machado et al., 2018, JAIR - URL: https://jair.org/index.php/jair/article/view/11182 - """ - if obs_type == "image": - logger.warn( - 'obs_type "image" should be replaced with the image type, one of: rgb, grayscale' - ) - obs_type = "rgb" - if obs_type not in self.metadata["obs_types"]: - raise Error( - f"Invalid observation type: {obs_type}. Expecting: rgb, grayscale, ram." - ) - - if type(frameskip) not in (int, tuple): - raise Error(f"Invalid frameskip type: {type(frameskip)}.") - if isinstance(frameskip, int) and frameskip <= 0: - raise Error( - f"Invalid frameskip of {frameskip}, frameskip must be positive." - ) - elif isinstance(frameskip, tuple) and len(frameskip) != 2: - raise Error( - f"Invalid stochastic frameskip length of {len(frameskip)}, expected length 2." - ) - elif isinstance(frameskip, tuple) and frameskip[0] > frameskip[1]: - raise Error( - "Invalid stochastic frameskip, lower bound is greater than upper bound." - ) - elif isinstance(frameskip, tuple) and frameskip[0] <= 0: - raise Error( - "Invalid stochastic frameskip lower bound is greater than upper bound." - ) - - if render_mode is not None and render_mode not in self.metadata["render_modes"]: - raise Error(f"Render mode {render_mode} not supported (rgb_array, human).") - - EzPickle.__init__( - self, - game, - mode, - difficulty, - obs_type, - frameskip, - repeat_action_probability, - full_action_space, - max_num_frames_per_episode, - render_mode, - ) - - # Initialize ALE - self.ale = ale_py.ALEInterface() - - self._game = rom_utils.rom_id_to_name(game) - - self._game_mode = mode - self._game_difficulty = difficulty - - self._frameskip = frameskip - self._obs_type = obs_type - self._render_mode = self.render_mode = render_mode - - # Set logger mode to error only - self.ale.setLoggerMode(ale_py.LoggerMode.Error) - # Config sticky action prob. - self.ale.setFloat("repeat_action_probability", repeat_action_probability) - - if max_num_frames_per_episode is not None: - self.ale.setInt("max_num_frames_per_episode", max_num_frames_per_episode) - - # If render mode is human we can display screen and sound - if render_mode == "human": - self.ale.setBool("display_screen", True) - self.ale.setBool("sound", True) - - # Seed + Load - self.seed() - - if full_action_space: - self._action_set = self.ale.getLegalActionSet() - else: - self._action_set = self.ale.getMinimalActionSet() - self.action_space = Discrete(len(self._action_set)) - - # Initialize observation type - if self._obs_type == "ram": - self.observation_space = Box( - low=0, high=255, dtype=np.uint8, shape=(self.ale.getRAMSize(),) - ) - elif self._obs_type == "rgb" or self._obs_type == "grayscale": - image_shape = self.ale.getScreenDims() - if self._obs_type == "rgb": - image_shape += (3,) - self.observation_space = Box( - low=0, high=255, dtype=np.uint8, shape=image_shape - ) - else: - raise Error(f"Unrecognized observation type: {self._obs_type}") - - def seed(self, seed: int | None = None) -> tuple[int, int]: - """Seeds both the internal numpy rng for stochastic frame skip and the ALE RNG. - - This function must also initialize the ROM and set the corresponding - mode and difficulty. `seed` may be called to initialize the environment - during deserialization by Gymnasium so these side-effects must reside here. - - Args: - seed: int => Manually set the seed for RNG. - - Returns: - tuple[int, int] => (np seed, ALE seed) - """ - ss = np.random.SeedSequence(seed) - seed1, seed2 = ss.generate_state(n_words=2) - - self.np_random = np.random.default_rng(seed1) - # ALE only takes signed integers for `setInt`, it'll get converted back - # to unsigned in StellaEnvironment. - self.ale.setInt("random_seed", seed2.astype(np.int32)) - - if not hasattr(roms, self._game): - raise Error( - f'We\'re Unable to find the game "{self._game}". Note: Gymnasium no longer distributes ROMs. ' - f"If you own a license to use the necessary ROMs for research purposes you can download them " - f'via `pip install gymnasium[accept-rom-license]`. Otherwise, you should try importing "{self._game}" ' - f'via the command `ale-import-roms`. If you believe this is a mistake perhaps your copy of "{self._game}" ' - "is unsupported. To check if this is the case try providing the environment variable " - "`PYTHONWARNINGS=default::ImportWarning:ale_py.roms`. For more information see: " - "https://github.com/mgbellemare/Arcade-Learning-Environment#rom-management" - ) - self.ale.loadROM(getattr(roms, self._game)) - - if self._game_mode is not None: - self.ale.setMode(self._game_mode) - if self._game_difficulty is not None: - self.ale.setDifficulty(self._game_difficulty) - - return seed1, seed2 - - def reset( - self, - *, - seed: int | None = None, - options: dict[str, Any] | None = None, - ) -> tuple[np.ndarray, AtariEnvStepMetadata]: - """Resets environment and returns initial observation. - - Args: - seed: The reset seed - options: The reset options - - Returns: - The reset observation and info - """ - super().reset(seed=seed, options=options) - del options - # Gymnasium's new seeding API seeds on reset. - # This will cause the console to be recreated - # and loose all previous state, e.g., statistics, etc. - seeded_with = None - if seed is not None: - seeded_with = self.seed(seed) - - self.ale.reset_game() - obs = self._get_obs() - - info = self._get_info() - if seeded_with is not None: - info["seeds"] = seeded_with - return obs, info - - def step( - self, - action_ind: int, - ) -> tuple[np.ndarray, float, bool, bool, AtariEnvStepMetadata]: - """Perform one agent step, i.e., repeats `action` frameskip # of steps. - - Args: - action_ind: int => Action index to execute - Returns: - Tuple[np.ndarray, float, bool, Dict[str, Any]] => observation, reward, terminal, metadata - - Note: `metadata` contains the keys "lives" and "rgb" if render_mode == 'rgb_array'. - """ - # Get action enum, terminal bool, metadata - action = self._action_set[action_ind] - - # If frameskip is a length 2 tuple then it's stochastic - # frameskip between [frameskip[0], frameskip[1]] uniformly. - if isinstance(self._frameskip, int): - frameskip = self._frameskip - elif isinstance(self._frameskip, tuple): - frameskip = self.np_random.integers(*self._frameskip) - else: - raise Error(f"Invalid frameskip type: {self._frameskip}") - - # Frameskip - reward = 0.0 - for _ in range(frameskip): - reward += self.ale.act(action) - is_terminal = self.ale.game_over(with_truncation=False) - is_truncated = self.ale.game_truncated() - - return self._get_obs(), reward, is_terminal, is_truncated, self._get_info() - - def render(self) -> Any: - """Renders the ALE environment. - - Returns: - If render_mode is "rgb_array", returns the screen RGB view. - """ - if self.render_mode == "rgb_array": - return self.ale.getScreenRGB() - elif self.render_mode == "human": - pass - else: - raise Error( - f"Invalid render mode `{self.render_mode}`. Supported modes: `human`, `rgb_array`." - ) - - def _get_obs(self) -> np.ndarray: - """Retrieves the current observation, dependent on `self._obs_type`. - - Returns: - The current observation - """ - if self._obs_type == "ram": - return self.ale.getRAM() - elif self._obs_type == "rgb": - return self.ale.getScreenRGB() - elif self._obs_type == "grayscale": - return self.ale.getScreenGrayscale() - else: - raise Error(f"Unrecognized observation type: {self._obs_type}") - - def _get_info(self) -> AtariEnvStepMetadata: - return { - "lives": self.ale.lives(), - "episode_frame_number": self.ale.getEpisodeFrameNumber(), - "frame_number": self.ale.getFrameNumber(), - } - - def get_keys_to_action(self) -> dict[tuple[int], ale_py.Action]: - """Return keymapping -> actions for human play. - - Returns: - A dictionary of keys to actions. - """ - UP = ord("w") - LEFT = ord("a") - RIGHT = ord("d") - DOWN = ord("s") - FIRE = ord(" ") - - mapping = { - ale_py.Action.NOOP: (None,), - ale_py.Action.UP: (UP,), - ale_py.Action.FIRE: (FIRE,), - ale_py.Action.DOWN: (DOWN,), - ale_py.Action.LEFT: (LEFT,), - ale_py.Action.RIGHT: (RIGHT,), - ale_py.Action.UPFIRE: (UP, FIRE), - ale_py.Action.DOWNFIRE: (DOWN, FIRE), - ale_py.Action.LEFTFIRE: (LEFT, FIRE), - ale_py.Action.RIGHTFIRE: (RIGHT, FIRE), - ale_py.Action.UPLEFT: (UP, LEFT), - ale_py.Action.UPRIGHT: (UP, RIGHT), - ale_py.Action.DOWNLEFT: (DOWN, LEFT), - ale_py.Action.DOWNRIGHT: (DOWN, RIGHT), - ale_py.Action.UPLEFTFIRE: (UP, LEFT, FIRE), - ale_py.Action.UPRIGHTFIRE: (UP, RIGHT, FIRE), - ale_py.Action.DOWNLEFTFIRE: (DOWN, LEFT, FIRE), - ale_py.Action.DOWNRIGHTFIRE: (DOWN, RIGHT, FIRE), - } - - # Map - # (key, key, ...) -> action_idx - # where action_idx is the integer value of the action enum - # - actions = self._action_set - return dict( - zip( - map(lambda action: tuple(sorted(mapping[action])), actions), - range(len(actions)), - ) - ) - - def get_action_meanings(self) -> list[str]: - """Return the meaning of each integer action. - - Returns: - A list of action meaning - """ - keys = ale_py.Action.__members__.values() - values = ale_py.Action.__members__.keys() - mapping = dict(zip(keys, values)) - return [mapping[action] for action in self._action_set] - - def clone_state(self, include_rng: bool = False) -> ale_py.ALEState: - """Clone emulator state w/o system state. - - Restoring this state will *not* give an identical environment. - For complete cloning and restoring of the full state, see `{clone,restore}_full_state()`. - - Args: - include_rng: If to include the rng in the cloned state - - Returns: - The cloned state - """ - return self.ale.cloneState(include_rng=include_rng) - - def restore_state(self, state: ale_py.ALEState): - """Restore emulator state w/o system state. - - Args: - state: The state to restore - """ - self.ale.restoreState(state) - - def clone_full_state(self) -> ale_py.ALEState: - """Deprecated method which would clone the emulator and system state.""" - logger.warn( - "`clone_full_state()` is deprecated and will be removed in a future release of `ale-py`. " - "Please use `clone_state(include_rng=True)` which is equivalent to `clone_full_state`. " - ) - return self.ale.cloneSystemState() - - def restore_full_state(self, state: ale_py.ALEState): - """Restore emulator state w/ system state including pseudo-randomness.""" - logger.warn( - "restore_full_state() is deprecated and will be removed in a future release of `ale-py`. " - "Please use `restore_state(state)` which will restore the state regardless of being a full or partial state. " - ) - self.ale.restoreSystemState(state) diff --git a/shimmy/dm_control_multiagent_compatibility.py b/shimmy/dm_control_multiagent_compatibility.py index e5189590..5a61b3ce 100644 --- a/shimmy/dm_control_multiagent_compatibility.py +++ b/shimmy/dm_control_multiagent_compatibility.py @@ -240,6 +240,7 @@ def reset( if self.render_mode == "human": self.viewer.close() + assert self._env.physics is not None self.viewer = MujocoRenderer( self._env.physics.model.ptr, self._env.physics.data.ptr ) diff --git a/shimmy/meltingpot_compatibility.py b/shimmy/meltingpot_compatibility.py index e2204a1e..91183b4d 100644 --- a/shimmy/meltingpot_compatibility.py +++ b/shimmy/meltingpot_compatibility.py @@ -8,7 +8,6 @@ from __future__ import annotations import functools -from itertools import repeat from typing import TYPE_CHECKING, Any, Optional import dm_env @@ -21,7 +20,6 @@ import shimmy.utils.meltingpot as utils if TYPE_CHECKING: - import meltingpot from meltingpot.utils.substrates import substrate diff --git a/shimmy/openai_gym_compatibility.py b/shimmy/openai_gym_compatibility.py index 3758adf2..cadbd6eb 100644 --- a/shimmy/openai_gym_compatibility.py +++ b/shimmy/openai_gym_compatibility.py @@ -2,7 +2,6 @@ # pyright: reportGeneralTypeIssues=false, reportPrivateImportUsage=false from __future__ import annotations -import sys from typing import Any, Protocol, runtime_checkable import gymnasium diff --git a/shimmy/registration.py b/shimmy/registration.py index a6b3e31f..492ddbd4 100644 --- a/shimmy/registration.py +++ b/shimmy/registration.py @@ -1,19 +1,16 @@ """Registers environments within gymnasium for optional modules.""" from __future__ import annotations -from collections import defaultdict from functools import partial -from typing import Any, Callable, Mapping, NamedTuple, Sequence +from typing import Any, Callable, Mapping import numpy as np from gymnasium.envs.registration import register, registry from shimmy.utils.envs_configs import ( - ALL_ATARI_GAMES, BSUITE_ENVS, DM_CONTROL_MANIPULATION_ENVS, DM_CONTROL_SUITE_ENVS, - LEGACY_ATARI_GAMES, ) @@ -144,126 +141,6 @@ def _make_dm_control_manipulation_env(env_name: str, **render_kwargs): ) -class GymFlavour(NamedTuple): - """A Gym Flavour.""" - - suffix: str - kwargs: Mapping[str, Any] | Callable[[str], Mapping[str, Any]] - - -class GymConfig(NamedTuple): - """A Gym Configuration.""" - - version: str - kwargs: Mapping[str, Any] - flavours: Sequence[GymFlavour] - - -def _register_atari_configs( - roms: Sequence[str], - obs_types: Sequence[str], - configs: Sequence[GymConfig], - prefix: str = "", -): - from ale_py.roms import utils as rom_utils - - for rom in roms: - for obs_type in obs_types: - for config in configs: - for flavour in config.flavours: - name = rom_utils.rom_id_to_name(rom) - if obs_type == "ram": - name = f"{name}-ram" - - # Parse config kwargs - if callable(config.kwargs): - config_kwargs = config.kwargs(rom) - else: - config_kwargs = config.kwargs - - # Parse flavour kwargs - if callable(flavour.kwargs): - flavour_kwargs = flavour.kwargs(rom) - else: - flavour_kwargs = flavour.kwargs - - # Register the environment - register( - id=f"{prefix}{name}{flavour.suffix}-{config.version}", - entry_point="shimmy.atari_env:AtariEnv", - kwargs={ - "game": rom, - "obs_type": obs_type, - **config_kwargs, - **flavour_kwargs, - }, - ) - - -def _register_atari_envs(): - try: - import ale_py - except ImportError: - return - - frameskip: dict[str, int] = defaultdict(lambda: 4, [("space_invaders", 3)]) - - configs = [ - GymConfig( - version="v0", - kwargs={ - "repeat_action_probability": 0.25, - "full_action_space": False, - "max_num_frames_per_episode": 108_000, - }, - flavours=[ - # Default for v0 has 10k steps, no idea why... - GymFlavour("", {"frameskip": (2, 5)}), - # Deterministic has 100k steps, close to the standard of 108k (30 mins gameplay) - GymFlavour("Deterministic", lambda rom: {"frameskip": frameskip[rom]}), - # NoFrameSkip imposes a max episode steps of frameskip * 100k, weird... - GymFlavour("NoFrameskip", {"frameskip": 1}), - ], - ), - GymConfig( - version="v4", - kwargs={ - "repeat_action_probability": 0.0, - "full_action_space": False, - "max_num_frames_per_episode": 108_000, - }, - flavours=[ - # Unlike v0, v4 has 100k max episode steps - GymFlavour("", {"frameskip": (2, 5)}), - GymFlavour("Deterministic", lambda rom: {"frameskip": frameskip[rom]}), - # Same weird frameskip * 100k max steps for v4? - GymFlavour("NoFrameskip", {"frameskip": 1}), - ], - ), - ] - _register_atari_configs( - LEGACY_ATARI_GAMES, obs_types=("rgb", "ram"), configs=configs - ) - - # max_episode_steps is 108k frames which is 30 mins of gameplay. - # This corresponds to 108k / 4 = 27,000 steps - configs = [ - GymConfig( - version="v5", - kwargs={ - "repeat_action_probability": 0.25, - "full_action_space": False, - "frameskip": 4, - "max_num_frames_per_episode": 108_000, - }, - flavours=[GymFlavour("", {})], - ) - ] - _register_atari_configs( - ALL_ATARI_GAMES, obs_types=("rgb", "ram"), configs=configs, prefix="ALE/" - ) - - def _register_dm_lab(): try: import deepmind_lab @@ -301,5 +178,4 @@ def register_gymnasium_envs(): _register_bsuite_envs() _register_dm_control_envs() - _register_atari_envs() _register_dm_lab() diff --git a/shimmy/utils/envs_configs.py b/shimmy/utils/envs_configs.py index 76ecf217..c6857135 100644 --- a/shimmy/utils/envs_configs.py +++ b/shimmy/utils/envs_configs.py @@ -108,174 +108,3 @@ "reach_site_features", "reach_site_vision", ) - -ALL_ATARI_GAMES = ( - "adventure", - "air_raid", - "alien", - "amidar", - "assault", - "asterix", - "asteroids", - "atlantis", - "atlantis2", - "backgammon", - "bank_heist", - "basic_math", - "battle_zone", - "beam_rider", - "berzerk", - "blackjack", - "bowling", - "boxing", - "breakout", - "carnival", - "casino", - "centipede", - "chopper_command", - "crazy_climber", - "crossbow", - "darkchambers", - "defender", - "demon_attack", - "donkey_kong", - "double_dunk", - "earthworld", - "elevator_action", - "enduro", - "entombed", - "et", - "fishing_derby", - "flag_capture", - "freeway", - "frogger", - "frostbite", - "galaxian", - "gopher", - "gravitar", - "hangman", - "haunted_house", - "hero", - "human_cannonball", - "ice_hockey", - "jamesbond", - "journey_escape", - "kaboom", - "kangaroo", - "keystone_kapers", - "king_kong", - "klax", - "koolaid", - "krull", - "kung_fu_master", - "laser_gates", - "lost_luggage", - "mario_bros", - "miniature_golf", - "montezuma_revenge", - "mr_do", - "ms_pacman", - "name_this_game", - "othello", - "pacman", - "phoenix", - "pitfall", - "pitfall2", - "pong", - "pooyan", - "private_eye", - "qbert", - "riverraid", - "road_runner", - "robotank", - "seaquest", - "sir_lancelot", - "skiing", - "solaris", - "space_invaders", - "space_war", - "star_gunner", - "superman", - "surround", - "tennis", - "tetris", - "tic_tac_toe_3d", - "time_pilot", - "trondead", - "turmoil", - "tutankham", - "up_n_down", - "venture", - "video_checkers", - "video_chess", - "video_cube", - "video_pinball", - "wizard_of_wor", - "word_zapper", - "yars_revenge", - "zaxxon", -) -LEGACY_ATARI_GAMES = ( - "adventure", - "air_raid", - "alien", - "amidar", - "assault", - "asterix", - "asteroids", - "atlantis", - "bank_heist", - "battle_zone", - "beam_rider", - "berzerk", - "bowling", - "boxing", - "breakout", - "carnival", - "centipede", - "chopper_command", - "crazy_climber", - "defender", - "demon_attack", - "double_dunk", - "elevator_action", - "enduro", - "fishing_derby", - "freeway", - "frostbite", - "gopher", - "gravitar", - "hero", - "ice_hockey", - "jamesbond", - "journey_escape", - "kangaroo", - "krull", - "kung_fu_master", - "montezuma_revenge", - "ms_pacman", - "name_this_game", - "phoenix", - "pitfall", - "pong", - "pooyan", - "private_eye", - "qbert", - "riverraid", - "road_runner", - "robotank", - "seaquest", - "skiing", - "solaris", - "space_invaders", - "star_gunner", - "tennis", - "time_pilot", - "tutankham", - "up_n_down", - "venture", - "video_pinball", - "wizard_of_wor", - "yars_revenge", - "zaxxon", -) diff --git a/tests/test_atari.py b/tests/test_atari.py deleted file mode 100644 index e647fd40..00000000 --- a/tests/test_atari.py +++ /dev/null @@ -1,110 +0,0 @@ -"""Tests the ale-py environments are correctly registered.""" -import pickle -import warnings - -import gymnasium as gym -import pytest -from ale_py import roms -from ale_py.roms import utils as rom_utils -from gymnasium.envs.registration import registry -from gymnasium.error import Error -from gymnasium.utils.env_checker import check_env, data_equivalence - -from shimmy.utils.envs_configs import ALL_ATARI_GAMES - - -def test_all_atari_roms(): - """Tests that the static variable ALL_ATARI_GAME is equal to all actual roms.""" - assert ALL_ATARI_GAMES == tuple(map(rom_utils.rom_name_to_id, dir(roms))) - - -CHECK_ENV_IGNORE_WARNINGS = [ - f"\x1b[33mWARN: {message}\x1b[0m" - for message in [ - "Official support for the `seed` function is dropped. Standard practice is to reset gymnasium environments using `env.reset(seed=)`", - "No render fps was declared in the environment (env.metadata['render_fps'] is None or not defined), rendering may occur at inconsistent fps.", - ] -] - - -@pytest.mark.parametrize( - "env_id", - [ - env_id - for env_id, env_spec in registry.items() - if "Pong" in env_id and env_spec.entry_point == "shimmy.atari_env:AtariEnv" - ], -) -def test_atari_envs(env_id): - """Tests the atari envs, as there are 1000 possible environment, we only test the Pong variants.""" - env = gym.make(env_id) - - with warnings.catch_warnings(record=True) as caught_warnings: - check_env(env.unwrapped) - - env.close() - - for warning_message in caught_warnings: - assert isinstance(warning_message.message, Warning) - if warning_message.message.args[0] not in CHECK_ENV_IGNORE_WARNINGS: - raise Error(f"Unexpected warning: {warning_message.message}") - - -@pytest.mark.parametrize( - "env_id", - [ - env_id - for env_id, env_spec in registry.items() - if "Pong" in env_id and env_spec.entry_point == "shimmy.atari_env:AtariEnv" - ], -) -def test_atari_pickle(env_id): - """Tests the atari envs, as there are 1000 possible environment, we only test the Pong variants.""" - env_1 = gym.make(env_id) - env_2 = pickle.loads(pickle.dumps(env_1)) - - obs_1, info_1 = env_1.reset(seed=42) - obs_2, info_2 = env_2.reset(seed=42) - assert data_equivalence(obs_1, obs_2) - assert data_equivalence(info_1, info_2) - for _ in range(100): - actions = int(env_1.action_space.sample()) - obs_1, reward_1, term_1, trunc_1, info_1 = env_1.step(actions) - obs_2, reward_2, term_2, trunc_2, info_2 = env_2.step(actions) - assert data_equivalence(obs_1, obs_2) - assert reward_1 == reward_2 - assert term_1 == term_2 and trunc_1 == trunc_2 - assert data_equivalence(info_1, info_2) - - env_1.close() - env_2.close() - - -@pytest.mark.parametrize( - "env_id", - [ - env_id - for env_id, env_spec in registry.items() - if "Pong" in env_id and env_spec.entry_point == "shimmy.atari_env:AtariEnv" - ], -) -def test_atari_seeding(env_id): - """Tests the seeding of the atari conversion wrapper.""" - env_1 = gym.make(env_id) - env_2 = gym.make(env_id) - - obs_1, info_1 = env_1.reset(seed=42) - obs_2, info_2 = env_2.reset(seed=42) - assert data_equivalence(obs_1, obs_2) - assert data_equivalence(info_1, info_2) - for _ in range(100): - actions = int(env_1.action_space.sample()) - obs_1, reward_1, term_1, trunc_1, info_1 = env_1.step(actions) - obs_2, reward_2, term_2, trunc_2, info_2 = env_2.step(actions) - assert data_equivalence(obs_1, obs_2) - assert reward_1 == reward_2 - assert term_1 == term_2 and trunc_1 == trunc_2 - assert data_equivalence(info_1, info_2) - - env_1.close() - env_2.close() diff --git a/tests/test_bsuite.py b/tests/test_bsuite.py index d8be2582..0d76c3ea 100644 --- a/tests/test_bsuite.py +++ b/tests/test_bsuite.py @@ -9,6 +9,10 @@ from gymnasium.error import Error from gymnasium.utils.env_checker import check_env, data_equivalence +import shimmy + +gym.register_envs(shimmy) + BSUITE_ENV_IDS = [ env_id for env_id in registry @@ -56,7 +60,7 @@ def test_bsuite_suite_envs(): f"\x1b[33mWARN: {message}\x1b[0m" for message in [ "A Box observation space minimum value is -infinity. This is probably too low.", - "A Box observation space maximum value is -infinity. This is probably too high.", + "A Box observation space maximum value is infinity. This is probably too high.", "A Box observation space has an unconventional shape (neither an image, nor a 1D vector). We recommend flattening the observation to have only a 1D vector or use a custom policy to properly process the data. Actual observation shape: (28, 28)", "A Box observation space has an unconventional shape (neither an image, nor a 1D vector). We recommend flattening the observation to have only a 1D vector or use a custom policy to properly process the data. Actual observation shape: (42, 42)", "A Box observation space has an unconventional shape (neither an image, nor a 1D vector). We recommend flattening the observation to have only a 1D vector or use a custom policy to properly process the data. Actual observation shape: (10, 5)", diff --git a/tests/test_dm_control.py b/tests/test_dm_control.py index 8dfc599f..d1122455 100644 --- a/tests/test_dm_control.py +++ b/tests/test_dm_control.py @@ -18,9 +18,13 @@ from gymnasium.error import Error from gymnasium.utils.env_checker import check_env, data_equivalence +import shimmy from shimmy.dm_control_compatibility import DmControlCompatibilityV0 from shimmy.registration import DM_CONTROL_SUITE_ENVS +gym.register_envs(shimmy) + + DM_CONTROL_ENV_IDS = [ env_id for env_id in registry @@ -38,7 +42,7 @@ def test_dm_control_suite_envs(): f"\x1b[33mWARN: {message}\x1b[0m" for message in [ "A Box observation space minimum value is -infinity. This is probably too low.", - "A Box observation space maximum value is -infinity. This is probably too high.", + "A Box observation space maximum value is infinity. This is probably too high.", "For Box action spaces, we recommend using a symmetric and normalized space (range=[-1, 1] or [0, 1]). See https://stable-baselines3.readthedocs.io/en/master/guide/rl_tips.html for more information.", "A Box observation space has an unconventional shape (neither an image, nor a 1D vector). We recommend flattening the observation to have only a 1D vector or use a custom policy to properly process the data. Actual observation shape: ()", "A Box observation space has an unconventional shape (neither an image, nor a 1D vector). We recommend flattening the observation to have only a 1D vector or use a custom policy to properly process the data. Actual observation shape: (8, 2)", diff --git a/tests/test_dm_lab.py b/tests/test_dm_lab.py index 0819fd62..d5112a9b 100644 --- a/tests/test_dm_lab.py +++ b/tests/test_dm_lab.py @@ -3,13 +3,16 @@ # flake8: noqa F821 import pickle +import gymnasium as gym import pytest from gymnasium.utils.env_checker import check_env, data_equivalence +import shimmy from shimmy.dm_lab_compatibility import DmLabCompatibilityV0 -pytest.importorskip("deepmind_lab") -import deepmind_lab +gym.register_envs(shimmy) + +deepmind_lab = pytest.importorskip("deepmind_lab") LEVEL_NAMES = [ "lt_chasm", diff --git a/tests/test_gym.py b/tests/test_gym.py index e9675a8c..3a2eb739 100644 --- a/tests/test_gym.py +++ b/tests/test_gym.py @@ -9,14 +9,17 @@ from gymnasium.error import Error from gymnasium.utils.env_checker import check_env +import shimmy.openai_gym_compatibility from shimmy import GymV21CompatibilityV0, GymV26CompatibilityV0 +gymnasium.register_envs(shimmy.openai_gym_compatibility) + CHECK_ENV_IGNORE_WARNINGS = [ f"\x1b[33mWARN: {message}\x1b[0m" for message in [ "This version of the mujoco environments depends on the mujoco-py bindings, which are no longer maintained and may stop working. Please upgrade to the v4 versions of the environments (which depend on the mujoco python bindings instead), unless you are trying to precisely replicate previous works).", "A Box observation space minimum value is -infinity. This is probably too low.", - "A Box observation space maximum value is -infinity. This is probably too high.", + "A Box observation space maximum value is infinity. This is probably too high.", "For Box action spaces, we recommend using a symmetric and normalized space (range=[-1, 1] or [0, 1]). See https://stable-baselines3.readthedocs.io/en/master/guide/rl_tips.html for more information.", "The environment CartPole-v0 is out of date. You should consider upgrading to version `v1`.", ] diff --git a/tests/test_meltingpot.py b/tests/test_meltingpot.py index cc90bce5..1d874365 100644 --- a/tests/test_meltingpot.py +++ b/tests/test_meltingpot.py @@ -10,7 +10,6 @@ pytest.importorskip("meltingpot") -import meltingpot from meltingpot.substrate import SUBSTRATES from shimmy.meltingpot_compatibility import MeltingPotCompatibilityV0