Skip to content

Commit

Permalink
Remove render('human'), update seeding API, add frameskip validation
Browse files Browse the repository at this point in the history
  • Loading branch information
JesseFarebro committed Apr 15, 2022
1 parent 80de842 commit 1e67fe1
Show file tree
Hide file tree
Showing 3 changed files with 64 additions and 46 deletions.
77 changes: 43 additions & 34 deletions src/gym/envs/atari/environment.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,12 @@
import warnings
from typing import Optional, Union, Tuple, Dict, Any, List

import numpy as np
import gym
import gym.logger as logger

from gym import error, spaces
from gym import utils
from gym.utils import seeding

from typing import Optional, Union, Tuple, Dict, Any, List

import ale_py.roms as roms
from ale_py._ale_py import ALEInterface, ALEState, Action, LoggerMode
Expand Down Expand Up @@ -74,11 +74,22 @@ def __init__(
raise error.Error(
f"Invalid observation type: {obs_type}. Expecting: rgb, grayscale, ram."
)
if not (
isinstance(frameskip, int)
or (isinstance(frameskip, tuple) and len(frameskip) == 2)
):
raise error.Error(f"Invalid frameskip type: {frameskip}")

# Validate `frameskip` eagerly so a bad value fails at construction time.
# The exact-type test (rather than isinstance) also rejects subclasses of
# int/tuple, e.g. bool.
if type(frameskip) not in (int, tuple):
    raise error.Error(f"Invalid frameskip type: {type(frameskip)}.")

if isinstance(frameskip, int):
    # A plain int must be strictly positive.
    if frameskip <= 0:
        raise error.Error(
            f"Invalid frameskip of {frameskip}, frameskip must be positive."
        )
elif len(frameskip) != 2:
    # A tuple is interpreted as a (lower, upper) pair, so it needs exactly
    # two entries.
    raise error.Error(
        f"Invalid stochastic frameskip length of {len(frameskip)}, expected length 2."
    )
elif frameskip[0] > frameskip[1]:
    # NOTE(review): equal bounds pass this check; confirm the sampler used
    # for stochastic frameskip accepts lower == upper.
    raise error.Error(
        "Invalid stochastic frameskip, lower bound is greater than upper bound."
    )
elif frameskip[0] <= 0:
    # This branch previously reused the "greater than upper bound" text,
    # which misreported the actual problem (a non-positive lower bound).
    raise error.Error(
        f"Invalid stochastic frameskip lower bound of {frameskip[0]}, "
        "lower bound must be positive."
    )

if render_mode is not None and render_mode not in {"rgb_array", "human"}:
raise error.Error(
f"Render mode {render_mode} not supported (rgb_array, human)."
Expand All @@ -98,7 +109,6 @@ def __init__(

# Initialize ALE
self.ale = ALEInterface()
self.viewer = None

self._game = rom_id_to_name(game)

Expand All @@ -112,7 +122,8 @@ def __init__(
# Set logger mode to error only
self.ale.setLoggerMode(LoggerMode.Error)
# Config sticky action prob.
self.ale.setFloat("repeat_action_probability", repeat_action_probability)
self.ale.setFloat("repeat_action_probability",
repeat_action_probability)

# If render mode is human we can display screen and sound
if render_mode == "human":
Expand Down Expand Up @@ -146,7 +157,8 @@ def __init__(
low=0, high=255, dtype=np.uint8, shape=image_shape
)
else:
raise error.Error(f"Unrecognized observation type: {self._obs_type}")
raise error.Error(
f"Unrecognized observation type: {self._obs_type}")

def seed(self, seed: Optional[int] = None) -> Tuple[int, int]:
"""
Expand All @@ -162,10 +174,13 @@ def seed(self, seed: Optional[int] = None) -> Tuple[int, int]:
Returns:
tuple[int, int] => (np seed, ALE seed)
"""
self.np_random, seed1 = seeding.np_random(seed)
seed2 = seeding.hash_seed(seed1 + 1) % 2 ** 31
ss = np.random.SeedSequence(seed)
seed1, seed2 = ss.generate_state(n_words=2)

self.ale.setInt("random_seed", seed2)
self.np_random = np.random.default_rng(seed1)
# ALE only takes signed integers for `setInt`, it'll get converted back
# to unsigned in StellaEnvironment.
self.ale.setInt("random_seed", seed2.astype(np.int32))

if not hasattr(roms, self._game):
raise error.Error(
Expand Down Expand Up @@ -212,7 +227,7 @@ def step(self, action_ind: int) -> Tuple[np.ndarray, float, bool, Dict[str, Any]
if isinstance(self._frameskip, int):
frameskip = self._frameskip
elif isinstance(self._frameskip, tuple):
frameskip = self.np_random.randint(*self._frameskip)
frameskip = self.np_random.integers(*self._frameskip)
else:
raise error.Error(f"Invalid frameskip type: {self._frameskip}")

Expand All @@ -224,7 +239,7 @@ def step(self, action_ind: int) -> Tuple[np.ndarray, float, bool, Dict[str, Any]
return self._get_obs(), reward, terminal, self._get_info()

def reset(
self, *, seed: Optional[int] = None, return_info: bool = False
self, *, seed: Optional[int] = None, return_info: bool = False, options: Optional[Dict[str, Any]] = None
) -> Union[Tuple[np.ndarray, Dict[str, Any]], np.ndarray]:
"""
Resets environment and returns initial observation.
Expand All @@ -247,7 +262,7 @@ def reset(
else:
return obs

def render(self, mode: str) -> None:
def render(self, mode: str) -> Any:
"""
Render is not supported by ALE. We use a paradigm similar to
Gym3 which allows you to specify `render_mode` during construction.
Expand All @@ -261,28 +276,21 @@ def render(self, mode: str) -> None:
if mode == "rgb_array":
return img
elif mode == "human":
from gym.envs.classic_control import rendering

if self.viewer is None:
logger.warn(
(
"We strongly suggest supplying `render_mode` when "
"constructing your environment, e.g., gym.make(ID, render_mode='human'). "
"Using `render_mode` provides access to proper scaling, audio support, "
"and proper framerates."
)
warnings.warn(
(
"render('human') is deprecated. Please supply `render_mode` when "
"constructing your environment, e.g., gym.make(ID, render_mode='human'). "
"The new `render_mode` keyword argument supports DPI scaling, "
"audio support, and native framerates."
)
self.viewer = rendering.SimpleImageViewer()
self.viewer.imshow(img)
return self.viewer.isopen
)
return False

def close(self) -> None:
"""
Cleanup any leftovers by the environment
"""
if self.viewer is not None:
self.viewer.close()
self.viewer = None
pass

def _get_obs(self) -> np.ndarray:
"""
Expand All @@ -296,7 +304,8 @@ def _get_obs(self) -> np.ndarray:
elif self._obs_type == "grayscale":
return self.ale.getScreenGrayscale()
else:
raise error.Error(f"Unrecognized observation type: {self._obs_type}")
raise error.Error(
f"Unrecognized observation type: {self._obs_type}")

def _get_info(self) -> Dict[str, Any]:
info = {
Expand Down
31 changes: 20 additions & 11 deletions tests/python/gym/test_gym_interface.py
Original file line number Diff line number Diff line change
@@ -1,23 +1,24 @@
# fmt: off
import pytest

pytest.importorskip("gym")
pytest.importorskip("gym.envs.atari")

import numpy as np

from unittest.mock import patch
from itertools import product

from gym import spaces
from gym.envs.registration import registry
from gym.core import Env
from gym.utils.env_checker import check_env

from ale_py.gym import (
register_legacy_gym_envs,
_register_gym_configs,
register_gym_envs,
)
from gym import error
from gym.utils.env_checker import check_env
from gym.core import Env
from gym.envs.registration import registry
from gym.envs.atari.environment import AtariEnv
from gym import spaces
from itertools import product
from unittest.mock import patch
import numpy as np
# fmt: on


def test_register_legacy_env_id():
Expand Down Expand Up @@ -123,7 +124,8 @@ def test_register_gym_envs(test_rom_path):
suffixes = []
versions = ["-v5"]

all_ids = set(map("".join, product(games, obs_types, suffixes, versions)))
all_ids = set(map("".join, product(
games, obs_types, suffixes, versions)))
assert all_ids.issubset(envids)


Expand Down Expand Up @@ -331,6 +333,13 @@ def test_gym_reset_with_infos(tetris_gym):
assert "rgb" in info


@pytest.mark.parametrize("frameskip", [0, -1, 4.0, (-1, 5), (0, 5), (5, 2), (1, 2, 3)])
def test_frameskip_warnings(test_rom_path, frameskip):
    """Check that building an AtariEnv with an invalid `frameskip` raises `error.Error`."""
    # Make `ale_py.roms.Tetris` resolve to the test ROM path for the duration
    # of the construction attempt.
    rom_patch = patch(
        "ale_py.roms.Tetris", create=True, new_callable=lambda: test_rom_path
    )
    with rom_patch, pytest.raises(error.Error):
        AtariEnv('Tetris', frameskip=frameskip)


def test_gym_compliance(tetris_gym):
try:
check_env(tetris_gym)
Expand Down
2 changes: 1 addition & 1 deletion tests/python/gym/test_legacy_registration.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,7 @@ def test_legacy_env_specs():
"""
for spec in specs:
assert spec in registry.env_specs
kwargs = registry.env_specs[spec]._kwargs
kwargs = registry.env_specs[spec].kwargs
max_episode_steps = registry.env_specs[spec].max_episode_steps

# Assert necessary parameters are set
Expand Down

0 comments on commit 1e67fe1

Please sign in to comment.