From 28a72515d95b386ba3d0eec8262addb56ba3bfb1 Mon Sep 17 00:00:00 2001 From: Alexis DUBURCQ Date: Fri, 17 Jan 2025 21:30:56 +0100 Subject: [PATCH] :Fix `get_wrapper_attr` / `set_wrapper_attr`. IMHO, the current implementation of `get_wrapper_attr` and `set_wrapper_attr` is flawed. --- gymnasium/core.py | 24 +++++++++++------------- tests/test_core.py | 14 ++++++++++++-- 2 files changed, 23 insertions(+), 15 deletions(-) diff --git a/gymnasium/core.py b/gymnasium/core.py index 9dfe63876..d86a8699c 100644 --- a/gymnasium/core.py +++ b/gymnasium/core.py @@ -19,6 +19,8 @@ ActType = TypeVar("ActType") RenderFrame = TypeVar("RenderFrame") +NOT_FOUND = object() + class Env(Generic[ObsType, ActType]): r"""The main Gymnasium class for implementing Reinforcement Learning Agents environments. @@ -415,15 +417,15 @@ def get_wrapper_attr(self, name: str) -> Any: Returns: The variable with name in wrapper or lower environments """ - if hasattr(self, name): - return getattr(self, name) - else: + attr = getattr(self, name, NOT_FOUND) + if attr is NOT_FOUND: try: return self.env.get_wrapper_attr(name) except AttributeError as e: raise AttributeError( f"wrapper {self.class_name()} has no attribute {name!r}" ) from e + return attr def set_wrapper_attr(self, name: str, value: Any): """Sets an attribute on this wrapper or lower environment if `name` is already defined. @@ -432,18 +434,14 @@ def set_wrapper_attr(self, name: str, value: Any): name: The variable name value: The new variable value """ - sub_env = self.env - attr_set = False - - while attr_set is False and isinstance(sub_env, Wrapper): + sub_env = self + while isinstance(sub_env, Wrapper): if hasattr(sub_env, name): setattr(sub_env, name, value) - attr_set = True - else: - sub_env = sub_env.env - - if attr_set is False: - setattr(sub_env, name, value) + break + sub_env = sub_env.env + else: + setattr(self, name, value) def __str__(self): """Returns the wrapper name and the :attr:`env` representation string.""" diff --git a/tests/test_core.py b/tests/test_core.py index 196b64f73..5db778a55 100644 --- a/tests/test_core.py +++ b/tests/test_core.py @@ -205,7 +205,7 @@ def test_get_set_wrapper_attr(): with pytest.raises(AttributeError): env.unwrapped._disable_render_order_enforcing assert env.has_wrapper_attr("_disable_render_order_enforcing") - assert env.get_wrapper_attr("_disable_render_order_enforcing") is False + assert not env.get_wrapper_attr("_disable_render_order_enforcing") env.set_wrapper_attr("_disable_render_order_enforcing", True) @@ -213,7 +213,17 @@ def test_get_set_wrapper_attr(): env._disable_render_order_enforcing with pytest.raises(AttributeError): env.unwrapped._disable_render_order_enforcing - assert env.get_wrapper_attr("_disable_render_order_enforcing") is True + assert env.get_wrapper_attr("_disable_render_order_enforcing") + + # Test with top-most wrapper + env.MY_ATTRIBUTE_1 = True + assert env.get_wrapper_attr("MY_ATTRIBUTE_1") + env.set_wrapper_attr("MY_ATTRIBUTE_1", False) + assert not env.get_wrapper_attr("MY_ATTRIBUTE_1") + + # Test with non-existing attribute + env.set_wrapper_attr("MY_ATTRIBUTE_2", False) + assert hasattr(env, "MY_ATTRIBUTE_2") class TestRandomSeeding: