Skip to content

Commit

Permalink
Add __getattr__ to Gym compatibility env, add registration of dm-la…
Browse files Browse the repository at this point in the history
…b compatibility environment. (#22)
  • Loading branch information
pseudo-rnd-thoughts authored Dec 14, 2022
1 parent 3e7ce13 commit 6e82d3d
Show file tree
Hide file tree
Showing 10 changed files with 133 additions and 43 deletions.
8 changes: 5 additions & 3 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ def get_version():
"atari": ["ale-py~=0.8.0"],
# "imageio" should be "gymnasium[mujoco]>=0.26" but there are install conflicts
"dm-control": ["dm-control>=1.0.8", "imageio", "h5py>=3.7.0"],
"dm-control-multi-agent": ["dm-control>=1.0.8", "pettingzoo>=1.22"],
"openspiel": ["open_spiel>=1.2", "pettingzoo>=1.22"],
}
extras["all"] = list({lib for libs in extras.values() for lib in libs})
Expand All @@ -63,10 +64,11 @@ def get_version():
tests_require=extras["testing"],
extras_require=extras,
classifiers=[
"Programming Language :: Python :: 3.10",
"Programming Language :: Python :: 3.9",
"Programming Language :: Python :: 3.8",
"Programming Language :: Python :: 3.7",
"Programming Language :: Python :: 3.8",
"Programming Language :: Python :: 3.9",
"Programming Language :: Python :: 3.10",
"Programming Language :: Python :: 3.11",
"License :: OSI Approved :: MIT License",
"Operating System :: OS Independent",
],
Expand Down
43 changes: 33 additions & 10 deletions shimmy/__init__.py
Original file line number Diff line number Diff line change
@@ -1,29 +1,52 @@
"""API for converting popular non-gymnasium environments to a gymnasium compatible environment."""
from __future__ import annotations

from typing import Any

from shimmy.dm_lab_compatibility import DmLabCompatibilityV0
from shimmy.openai_gym_compatibility import GymV22CompatibilityV0, GymV26CompatibilityV0

__version__ = "0.2.0"


class NotInstallClass:
"""Rather than an attribute error, this raises a more helpful import error with install instructions for shimmy."""

def __init__(self, install_message: str, import_exception: ImportError):
self.install_message = install_message
self.import_exception = import_exception

def __call__(self, *args: list[Any], **kwargs: Any):
"""Acts like the `__init__` for the class."""
raise ImportError(self.install_message) from self.import_exception


try:
from shimmy.dm_control_compatibility import DmControlCompatibilityV0
except ImportError:
pass
except ImportError as e:
DmControlCompatibilityV0 = NotInstallClass(
"Dm-control is not installed, run `pip install 'shimmy[dm-control]'`", e
)


try:
from shimmy.dm_control_multiagent_compatibility import (
DmControlMultiAgentCompatibilityV0,
)
except ImportError:
pass
except ImportError as e:
DmControlMultiAgentCompatibilityV0 = NotInstallClass(
"Dm-control or Pettingzoo is not installed, run `pip install 'shimmy[dm-control-multi-agent]'`",
e,
)

try:
from shimmy.openspiel_compatibility import OpenspielCompatibilityV0
except ImportError:
pass
except ImportError as e:
OpenspielCompatibilityV0 = NotInstallClass(
"Openspiel or Pettingzoo is not installed, run `pip install 'shimmy[openspiel]'`",
e,
)

try:
from shimmy.dm_lab_compatibility import DmLabCompatibilityV0
except ImportError:
pass

__all__ = [
"DmControlCompatibilityV0",
Expand Down
13 changes: 11 additions & 2 deletions shimmy/atari_env.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,13 @@
"""ALE-py interface for atari."""
"""ALE-py interface for atari.
This file was originally copied from https://github.com/mgbellemare/Arcade-Learning-Environment/blob/master/src/python/env/gym.py
Under the GNU General Public License v2.0
Copyright is held by the authors
Changes
* Added `self.render_mode` which is identical to `self._render_mode`
"""
from __future__ import annotations

import sys
Expand Down Expand Up @@ -29,7 +38,7 @@ class AtariEnvStepMetadata(TypedDict):
seeds: NotRequired[Sequence[int]]


class AtariEnv(gymnasium.Env[np.ndarray, int], EzPickle):
class AtariEnv(gymnasium.Env[np.ndarray, np.int64], EzPickle):
"""(A)rcade (L)earning (Gymnasium) (Env)ironment.
A Gymnasium wrapper around the Arcade Learning Environment (ALE).
Expand Down
6 changes: 3 additions & 3 deletions shimmy/dm_control_compatibility.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,7 @@ def reset(

obs, reward, terminated, truncated, info = dm_control_step2gym_step(timestep)

return obs, info # pyright: ignore[reportGeneralTypeIssues]
return obs, info

def step(
self, action: np.ndarray
Expand All @@ -95,9 +95,9 @@ def step(
obs, reward, terminated, truncated, info = dm_control_step2gym_step(timestep)

if self.render_mode == "human":
self.viewer.render()
self.viewer.render(self.render_mode)

return ( # pyright: ignore[reportGeneralTypeIssues]
return (
obs,
reward,
terminated,
Expand Down
2 changes: 1 addition & 1 deletion shimmy/dm_control_multiagent_compatibility.py
Original file line number Diff line number Diff line change
Expand Up @@ -159,7 +159,7 @@ def step(self, actions):
)

if self.render_mode == "human":
self.viewer.render()
self.viewer.render(self.render_mode)

if any(terminations.values()) or any(truncations.values()):
self.agents = []
Expand Down
14 changes: 7 additions & 7 deletions shimmy/dm_lab_compatibility.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,16 @@
"""Wrapper to convert a dm_lab environment into a gymnasium compatible environment."""
from __future__ import annotations

from typing import Any, TypeVar
from typing import Any, Dict

import gymnasium
import gymnasium as gym
import numpy as np
from gymnasium.core import ObsType

from shimmy.utils.dm_lab import dm_lab_obs2gym_obs_space, dm_lab_spec2gym_space


class DmLabCompatibilityV0(gymnasium.Env[ObsType, np.ndarray]):
class DmLabCompatibilityV0(gym.Env[ObsType, Dict[str, np.ndarray]]):
"""A compatibility wrapper that converts a dm_lab-control environment into a gymnasium environment."""

metadata = {"render_modes": [], "render_fps": 10}
Expand Down Expand Up @@ -45,22 +45,22 @@ def reset(
return (
self._env.observations(),
info,
) # pyright: ignore[reportGeneralTypeIssues]
)

def step(
self, action: dict[str, np.ndarray]
) -> tuple[ObsType, float, bool, bool, dict[str, Any]]:
"""Steps through the dm-lab environment."""
# there's some funky quantization happening here, dm_lab only accepts ints as actions
action = np.array([a[0] for a in action.values()], dtype=np.intc)
reward = self._env.step(action)
action_array = np.array([a[0] for a in action.values()], dtype=np.intc)
reward = self._env.step(action_array)

obs = self._env.observations()
terminated = not self._env.is_running()
truncated = False
info = {}

return ( # pyright: ignore[reportGeneralTypeIssues]
return (
obs,
reward,
terminated,
Expand Down
18 changes: 12 additions & 6 deletions shimmy/openai_gym_compatibility.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,8 +39,8 @@
GYM_IMPORT_ERROR = None


class GymV26Compatibility(gymnasium.Env[ObsType, ActType]):
"""Converts a gym v26 environment to a gymnasium environment."""
class GymV26CompatibilityV0(gymnasium.Env[ObsType, ActType]):
"""Converts a Gym v26 environment to a Gymnasium environment."""

def __init__(
self,
Expand Down Expand Up @@ -83,6 +83,10 @@ def __init__(
self.reward_range = getattr(self.gym_env, "reward_range", None)
self.spec = getattr(self.gym_env, "spec", None)

def __getattr__(self, item: str):
"""Gets an attribute that only exists in the base environments."""
return getattr(self.gym_env, item)

def reset(
self, seed: int | None = None, options: dict | None = None
) -> tuple[ObsType, dict]:
Expand Down Expand Up @@ -151,7 +155,7 @@ def seed(self, seed: int | None = None):
...


class GymV22Compatibility(gymnasium.Env[ObsType, ActType]):
class GymV22CompatibilityV0(gymnasium.Env[ObsType, ActType]):
r"""A wrapper which can transform an environment from the old API to the new API.
Old step API refers to step() method returning (observation, reward, done, info), and reset() only retuning the observation.
Expand Down Expand Up @@ -201,6 +205,10 @@ def __init__(

self.gym_env: LegacyV22Env = gym_env

def __getattr__(self, item: str):
"""Gets an attribute that only exists in the base environments."""
return getattr(self.gym_env, item)

def reset(
self, seed: int | None = None, options: dict | None = None
) -> tuple[ObsType, dict]:
Expand Down Expand Up @@ -236,9 +244,7 @@ def step(self, action: ActType) -> tuple[Any, float, bool, bool, dict]:
if self.render_mode == "human":
self.render()

return convert_to_terminated_truncated_step_api(
(obs, reward, done, info)
) # pyright: ignore[reportGeneralTypeIssues]
return convert_to_terminated_truncated_step_api((obs, reward, done, info))

def render(self) -> Any:
"""Renders the environment.
Expand Down
28 changes: 23 additions & 5 deletions shimmy/registration.py
Original file line number Diff line number Diff line change
Expand Up @@ -227,14 +227,32 @@ def _register_atari_envs():
)


def _register_dm_lab():
try:
import deepmind_lab
except ImportError:
return

from shimmy.dm_lab_compatibility import DmLabCompatibilityV0

def _make_dm_lab_env(
env_id: str, observations, config: dict[str, Any], renderer: str
):
env = deepmind_lab.Lab(env_id, observations, config=config, renderer=renderer)
return DmLabCompatibilityV0(env)

register("DmLabCompatibility-v0", _make_dm_lab_env)


def register_gymnasium_envs():
"""This function is called when gymnasium is imported."""
_register_dm_control_envs()
_register_atari_envs()

register(
"GymV26Environment-v0", "shimmy.openai_gym_compatibility:GymV26Compatibility"
"GymV26Environment-v0", "shimmy.openai_gym_compatibility:GymV26CompatibilityV0"
)
register(
"GymV22Environment-v0", "shimmy.openai_gym_compatibility:GymV22Compatibility"
"GymV22Environment-v0", "shimmy.openai_gym_compatibility:GymV22CompatibilityV0"
)

_register_dm_control_envs()
_register_atari_envs()
_register_dm_lab()
5 changes: 2 additions & 3 deletions shimmy/utils/dm_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@
import numpy as np
from dm_env.specs import Array, BoundedArray, DiscreteArray
from gymnasium import spaces
from gymnasium.core import ObsType


def dm_spec2gym_space(spec) -> spaces.Space[Any]:
Expand Down Expand Up @@ -54,7 +53,7 @@ def dm_obs2gym_obs(obs) -> np.ndarray | dict[str, Any]:

def dm_control_step2gym_step(
timestep,
) -> tuple[ObsType, float, bool, bool, dict[str, Any]]:
) -> tuple[Any, float, bool, bool, dict[str, Any]]:
"""Opens up the timestep to return obs, reward, terminated, truncated, info."""
obs = dm_obs2gym_obs(timestep.observation)
reward = timestep.reward or 0
Expand All @@ -72,7 +71,7 @@ def dm_control_step2gym_step(
"timestep.step_type": timestep.step_type,
}

return ( # pyright: ignore[reportGeneralTypeIssues]
return (
obs,
reward,
terminated,
Expand Down
39 changes: 36 additions & 3 deletions tests/test_gym.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,15 @@

import warnings

import gym
import gym as openai_gym
import gymnasium
import pytest
from gym.spaces import Box as openai_Box
from gymnasium.error import Error
from gymnasium.utils.env_checker import check_env

from shimmy import GymV22CompatibilityV0, GymV26CompatibilityV0

CHECK_ENV_IGNORE_WARNINGS = [
f"\x1b[33mWARN: {message}\x1b[0m"
for message in [
Expand All @@ -21,7 +24,7 @@
# We do not test Atari environment's here because we check all variants of Pong in test_envs.py (There are too many Atari environments)
CLASSIC_CONTROL_ENVS = [
env_id
for env_id, spec in gym.envs.registry.items() # pyright: ignore[reportGeneralTypeIssues]
for env_id, spec in openai_gym.envs.registry.items() # pyright: ignore[reportGeneralTypeIssues]
if ("classic_control" in spec.entry_point)
]

Expand Down Expand Up @@ -51,9 +54,11 @@ def test_gym_conversion_by_id(env_id):
)
def test_gym_conversion_instantiated(env_id):
"""Tests that the gym conversion works with an instantiated gym environment."""
env = gym.make(env_id)
env = openai_gym.make(env_id)
env = gymnasium.make("GymV26Environment-v0", env=env).unwrapped

print("render-mode", env.render_mode)
print("render-modes", env.metadata)
with warnings.catch_warnings(record=True) as caught_warnings:
check_env(env, skip_render_check=True)

Expand All @@ -65,3 +70,31 @@ def test_gym_conversion_instantiated(env_id):
raise Error(f"Unexpected warning: {warning.message}")

env.close()


class EnvWithData(openai_gym.Env):
"""Environment with data that users might want to access."""

def __init__(self):
"""Initialises the environment with hidden data."""
self.observation_space = openai_Box(low=0, high=1)
self.action_space = openai_Box(low=0, high=1)

self.data = 123

def get_env_data(self):
"""Gets the environment data."""
return self.data


def test_compatibility_get_attr():
"""Tests that the compatibility environment works with `__getattr__` for those attributes."""
env = GymV22CompatibilityV0(env=EnvWithData())
assert env.data == 123
assert env.get_env_data() == 123
env.close()

env = GymV26CompatibilityV0(env=EnvWithData())
assert env.data == 123
assert env.get_env_data() == 123
env.close()

0 comments on commit 6e82d3d

Please sign in to comment.