Skip to content

Commit

Permalink
:Fix get_wrapper_attr / set_wrapper_attr.
Browse files Browse the repository at this point in the history
IMHO, the current implementation of `get_wrapper_attr` and `set_wrapper_attr` is flawed.
  • Loading branch information
duburcqa committed Jan 17, 2025
1 parent eaccbb5 commit 28a7251
Show file tree
Hide file tree
Showing 2 changed files with 23 additions and 15 deletions.
24 changes: 11 additions & 13 deletions gymnasium/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@
ActType = TypeVar("ActType")
RenderFrame = TypeVar("RenderFrame")

NOT_FOUND = object()


class Env(Generic[ObsType, ActType]):
r"""The main Gymnasium class for implementing Reinforcement Learning Agents environments.
Expand Down Expand Up @@ -415,15 +417,15 @@ def get_wrapper_attr(self, name: str) -> Any:
Returns:
The variable with name in wrapper or lower environments
"""
if hasattr(self, name):
return getattr(self, name)
else:
attr = getattr(self, name, NOT_FOUND)
if attr is NOT_FOUND:
try:
return self.env.get_wrapper_attr(name)
except AttributeError as e:
raise AttributeError(
f"wrapper {self.class_name()} has no attribute {name!r}"
) from e
return attr

def set_wrapper_attr(self, name: str, value: Any):
"""Sets an attribute on this wrapper or lower environment if `name` is already defined.
Expand All @@ -432,18 +434,14 @@ def set_wrapper_attr(self, name: str, value: Any):
name: The variable name
value: The new variable value
"""
sub_env = self.env
attr_set = False

while attr_set is False and isinstance(sub_env, Wrapper):
sub_env = self
while isinstance(sub_env, Wrapper):
if hasattr(sub_env, name):
setattr(sub_env, name, value)
attr_set = True
else:
sub_env = sub_env.env

if attr_set is False:
setattr(sub_env, name, value)
break
sub_env = sub_env.env
else:
setattr(self, name, value)

def __str__(self):
"""Returns the wrapper name and the :attr:`env` representation string."""
Expand Down
14 changes: 12 additions & 2 deletions tests/test_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -205,15 +205,25 @@ def test_get_set_wrapper_attr():
with pytest.raises(AttributeError):
env.unwrapped._disable_render_order_enforcing
assert env.has_wrapper_attr("_disable_render_order_enforcing")
assert env.get_wrapper_attr("_disable_render_order_enforcing") is False
assert not env.get_wrapper_attr("_disable_render_order_enforcing")

env.set_wrapper_attr("_disable_render_order_enforcing", True)

with pytest.raises(AttributeError):
env._disable_render_order_enforcing
with pytest.raises(AttributeError):
env.unwrapped._disable_render_order_enforcing
assert env.get_wrapper_attr("_disable_render_order_enforcing") is True
assert env.get_wrapper_attr("_disable_render_order_enforcing")

# Test with top-most wrapper
env.MY_ATTRIBUTE_1 = True
assert env.get_wrapper_attr("MY_ATTRIBUTE_1")
env.set_wrapper_attr("MY_ATTRIBUTE_1", False)
assert not env.get_wrapper_attr("MY_ATTRIBUTE_1")

# Test with non-existing attribute
env.set_wrapper_attr("MY_ATTRIBUTE_2", False)
assert hasattr(env, "MY_ATTRIBUTE_2")


class TestRandomSeeding:
Expand Down

0 comments on commit 28a7251

Please sign in to comment.