Skip to content

Commit

Permalink
Attempt to fix type issues
Browse files Browse the repository at this point in the history
  • Loading branch information
pseudo-rnd-thoughts committed Oct 16, 2024
1 parent be6239f commit 1e326d5
Show file tree
Hide file tree
Showing 30 changed files with 105 additions and 62 deletions.
13 changes: 8 additions & 5 deletions gymnasium/envs/registration.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,11 +12,12 @@
import re
import sys
from collections import defaultdict
from collections.abc import Iterable, Sequence
import collections.abc
from collections.abc import Callable, Iterable
from dataclasses import dataclass, field
from enum import Enum
from types import ModuleType
from typing import Any, Callable
from typing import Any, runtime_checkable

import gymnasium as gym
from gymnasium import Env, Wrapper, error, logger
Expand Down Expand Up @@ -51,12 +52,14 @@
]


@runtime_checkable
class EnvCreator(Protocol):
"""Function type expected for an environment."""

def __call__(self, **kwargs: Any) -> Env: ...


@runtime_checkable
class VectorEnvCreator(Protocol):
"""Function type expected for an environment."""

Expand Down Expand Up @@ -100,7 +103,7 @@ class EnvSpec:
entry_point: EnvCreator | str | None = field(default=None)

# Environment attributes
reward_threshold: float | None = field(default=None)
reward_threshold: float | int | None = field(default=None)
nondeterministic: bool = field(default=False)

# Wrappers
Expand Down Expand Up @@ -570,7 +573,7 @@ def namespace(ns: str):
def register(
id: str,
entry_point: EnvCreator | str | None = None,
reward_threshold: float | None = None,
reward_threshold: float | int | None = None,
nondeterministic: bool = False,
max_episode_steps: int | None = None,
order_enforce: bool = True,
Expand Down Expand Up @@ -836,7 +839,7 @@ def make_vec(
num_envs: int = 1,
vectorization_mode: VectorizeMode | str | None = None,
vector_kwargs: dict[str, Any] | None = None,
wrappers: Sequence[Callable[[Env], Wrapper]] | None = None,
wrappers: collections.abc.Sequence[Callable[[Env], Wrapper]] | None = None,
**kwargs,
) -> gym.vector.VectorEnv:
"""Create a vector environment according to the given ID.
Expand Down
3 changes: 2 additions & 1 deletion gymnasium/experimental/functional.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,8 @@

from __future__ import annotations

from typing import Any, Callable, Generic, TypeVar
from typing import Any, Generic, TypeVar
from collections.abc import Callable

import numpy as np

Expand Down
3 changes: 1 addition & 2 deletions gymnasium/spaces/dict.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
from __future__ import annotations

import collections.abc
import typing
from collections.abc import KeysView, Sequence
from typing import Any

Expand All @@ -12,7 +11,7 @@
from gymnasium.spaces.space import Space


class Dict(Space[dict[str, Any]], typing.Mapping[str, Space[Any]]):
class Dict(Space[dict[str, Any]], collections.abc.Mapping[str, Space[Any]]):
"""A dictionary of :class:`Space` instances.
Elements of this space are (ordered) dictionaries of elements from the constituent spaces.
Expand Down
6 changes: 3 additions & 3 deletions gymnasium/spaces/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,9 +23,9 @@ class GraphInstance(NamedTuple):
* edge_links (Optional[np.ndarray]): an (m x 2) sized array of ints representing the indices of the two nodes that each edge connects.
"""

nodes: NDArray[Any]
edges: NDArray[Any] | None
edge_links: NDArray[Any] | None
nodes: NDArray
edges: NDArray | None
edge_links: NDArray | None


class Graph(Space[GraphInstance]):
Expand Down
2 changes: 1 addition & 1 deletion gymnasium/spaces/multi_binary.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ class MultiBinary(Space[NDArray[np.int8]]):

def __init__(
self,
n: NDArray[np.integer[Any]] | Sequence[int] | int,
n: NDArray[np.integer] | Sequence[int] | int,
seed: int | np.random.Generator | None = None,
):
"""Constructor of :class:`MultiBinary` space.
Expand Down
6 changes: 3 additions & 3 deletions gymnasium/spaces/oneof.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

from __future__ import annotations

import typing
import collections.abc
from collections.abc import Iterable
from typing import Any

Expand Down Expand Up @@ -34,7 +34,7 @@ class OneOf(Space[Any]):
def __init__(
self,
spaces: Iterable[Space[Any]],
seed: int | typing.Sequence[int] | np.random.Generator | None = None,
seed: int | collections.abc.Sequence[int] | np.random.Generator | None = None,
):
r"""Constructor of :class:`OneOf` space.
Expand Down Expand Up @@ -143,7 +143,7 @@ def __repr__(self) -> str:
return "OneOf(" + ", ".join([str(s) for s in self.spaces]) + ")"

def to_jsonable(
self, sample_n: typing.Sequence[tuple[int, Any]]
self, sample_n: collections.abc.Sequence[tuple[int, Any]]
) -> list[list[Any]]:
"""Convert a batch of samples from this space to a JSONable data type."""
return [
Expand Down
4 changes: 2 additions & 2 deletions gymnasium/spaces/sequence.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,8 @@

from __future__ import annotations

import typing
from typing import Any, Union
import collections.abc

import numpy as np
from numpy.typing import NDArray
Expand Down Expand Up @@ -186,7 +186,7 @@ def __repr__(self) -> str:
return f"Sequence({self.feature_space}, stack={self.stack})"

def to_jsonable(
self, sample_n: typing.Sequence[tuple[Any, ...] | Any]
self, sample_n: collections.abc.Sequence[tuple[Any, ...] | Any]
) -> list[list[Any]]:
"""Convert a batch of samples from this space to a JSONable data type."""
if self.stack:
Expand Down
10 changes: 5 additions & 5 deletions gymnasium/spaces/tuple.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,16 +2,16 @@

from __future__ import annotations

import typing
from collections.abc import Iterable
import collections.abc
from typing import Any

import numpy as np

from gymnasium.spaces.space import Space


class Tuple(Space[tuple[Any, ...]], typing.Sequence[Any]):
class Tuple(Space[tuple[Any, ...]], collections.abc.Sequence[Any]):
"""A tuple (more precisely: the cartesian product) of :class:`Space` instances.
Elements of this space are tuples of elements of the constituent spaces.
Expand All @@ -26,7 +26,7 @@ class Tuple(Space[tuple[Any, ...]], typing.Sequence[Any]):
def __init__(
self,
spaces: Iterable[Space[Any]],
seed: int | typing.Sequence[int] | np.random.Generator | None = None,
seed: int | collections.abc.Sequence[int] | np.random.Generator | None = None,
):
r"""Constructor of :class:`Tuple` space.
Expand All @@ -48,7 +48,7 @@ def is_np_flattenable(self):
"""Checks whether this space can be flattened to a :class:`spaces.Box`."""
return all(space.is_np_flattenable for space in self.spaces)

def seed(self, seed: int | typing.Sequence[int] | None = None) -> tuple[int, ...]:
def seed(self, seed: int | collections.abc.Sequence[int] | None = None) -> tuple[int, ...]:
"""Seed the PRNG of this space and all subspaces.
Depending on the type of seed, the subspaces will be seeded differently
Expand Down Expand Up @@ -131,7 +131,7 @@ def __repr__(self) -> str:
return "Tuple(" + ", ".join([str(s) for s in self.spaces]) + ")"

def to_jsonable(
self, sample_n: typing.Sequence[tuple[Any, ...]]
self, sample_n: collections.abc.Sequence[tuple[Any, ...]]
) -> list[list[Any]]:
"""Convert a batch of samples from this space to a JSONable data type."""
# serialize as list-repr of tuple of vectors
Expand Down
4 changes: 2 additions & 2 deletions gymnasium/spaces/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -316,8 +316,8 @@ def _unflatten_discrete(space: Discrete, x: NDArray[np.int64]) -> np.int64:

@unflatten.register(MultiDiscrete)
def _unflatten_multidiscrete(
space: MultiDiscrete, x: NDArray[np.integer[Any]]
) -> NDArray[np.integer[Any]]:
space: MultiDiscrete, x: NDArray[np.integer]
) -> NDArray[np.integer]:
offsets = np.zeros((space.nvec.size + 1,), dtype=space.dtype)
offsets[1:] = np.cumsum(space.nvec.flatten())
nonzero = np.nonzero(x)
Expand Down
2 changes: 1 addition & 1 deletion gymnasium/utils/passive_env_checker.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
"""A set of functions for passively checking environment implementations."""

import inspect
from collections.abc import Callable
from functools import partial
from typing import Callable

import numpy as np

Expand Down
2 changes: 1 addition & 1 deletion gymnasium/utils/performance.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
"""A collection of runtime performance bencharks, useful for debugging performance related issues."""

import time
from typing import Callable
from collections.abc import Callable

import gymnasium

Expand Down
2 changes: 1 addition & 1 deletion gymnasium/utils/play.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from __future__ import annotations

from collections import deque
from typing import Callable
from collections.abc import Callable

import numpy as np

Expand Down
2 changes: 1 addition & 1 deletion gymnasium/utils/save_video.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from __future__ import annotations

import os
from typing import Callable
from collections.abc import Callable

import gymnasium as gym
from gymnasium import logger
Expand Down
12 changes: 7 additions & 5 deletions gymnasium/vector/async_vector_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,14 @@
import sys
import time
import traceback
from collections.abc import Sequence
import collections.abc
from copy import deepcopy
from enum import Enum
from multiprocessing import Queue
from multiprocessing.connection import Connection
from typing import Any, Callable
from multiprocessing.sharedctypes import SynchronizedArray
from typing import Any
from collections.abc import Callable

import numpy as np

Expand Down Expand Up @@ -90,7 +92,7 @@ class AsyncVectorEnv(VectorEnv):

def __init__(
self,
env_fns: Sequence[Callable[[], Env]],
env_fns: collections.abc.Sequence[Callable[[], Env]],
shared_memory: bool = True,
copy: bool = True,
context: str | None = None,
Expand Down Expand Up @@ -684,10 +686,10 @@ def __del__(self):

def _async_worker(
index: int,
env_fn: callable,
env_fn: Callable,
pipe: Connection,
parent_pipe: Connection,
shared_memory: multiprocessing.Array | dict[str, Any] | tuple[Any, ...],
shared_memory: SynchronizedArray | dict[str, Any] | tuple[Any, ...],
error_queue: Queue,
):
env = env_fn()
Expand Down
8 changes: 5 additions & 3 deletions gymnasium/vector/sync_vector_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,11 @@

from __future__ import annotations

from collections.abc import Iterator, Sequence
import collections.abc
from collections.abc import Iterator
from copy import deepcopy
from typing import Any, Callable
from typing import Any
from collections.abc import Callable

import numpy as np

Expand Down Expand Up @@ -63,7 +65,7 @@ class SyncVectorEnv(VectorEnv):

def __init__(
self,
env_fns: Iterator[Callable[[], Env]] | Sequence[Callable[[], Env]],
env_fns: Iterator[Callable[[], Env]] | collections.abc.Sequence[Callable[[], Env]],
copy: bool = True,
observation_mode: str | Space = "same",
):
Expand Down
7 changes: 4 additions & 3 deletions gymnasium/vector/utils/shared_memory.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from __future__ import annotations

import multiprocessing as mp
from multiprocessing.sharedctypes import SynchronizedArray
from ctypes import c_bool
from functools import singledispatch
from typing import Any
Expand Down Expand Up @@ -32,7 +33,7 @@
@singledispatch
def create_shared_memory(
space: Space[Any], n: int = 1, ctx=mp
) -> dict[str, Any] | tuple[Any, ...] | mp.Array:
) -> dict[str, Any] | tuple[Any, ...] | SynchronizedArray:
"""Create a shared memory object, to be shared across processes.
This eventually contains the observations from the vectorized environment.
Expand Down Expand Up @@ -109,7 +110,7 @@ def _create_dynamic_shared_memory(space: Graph | Sequence, n: int = 1, ctx=mp):

@singledispatch
def read_from_shared_memory(
space: Space, shared_memory: dict | tuple | mp.Array, n: int = 1
space: Space, shared_memory: dict[str, Any] | tuple[Any, ...] | SynchronizedArray, n: int = 1
) -> dict[str, Any] | tuple[Any, ...] | np.ndarray:
"""Read the batch of observations from shared memory as a numpy array.
Expand Down Expand Up @@ -209,7 +210,7 @@ def write_to_shared_memory(
space: Space,
index: int,
value: np.ndarray,
shared_memory: dict[str, Any] | tuple[Any, ...] | mp.Array,
shared_memory: dict[str, Any] | tuple[Any, ...] | SynchronizedArray,
):
"""Write the observation of a single environment into shared memory.
Expand Down
8 changes: 4 additions & 4 deletions gymnasium/vector/utils/space_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,8 @@

from __future__ import annotations

import typing
from collections.abc import Iterable, Iterator
import collections.abc
from collections.abc import Iterable, Iterator, Callable
from copy import deepcopy
from functools import singledispatch
from typing import Any
Expand Down Expand Up @@ -149,7 +149,7 @@ def _batch_space_custom(space: Graph | Text | Sequence | OneOf, n: int = 1):


@singledispatch
def batch_differing_spaces(spaces: typing.Sequence[Space]) -> Space:
def batch_differing_spaces(spaces: collections.abc.Sequence[Space]) -> Space:
"""Batch a Sequence of spaces where subspaces to contain minor differences.
Args:
Expand Down Expand Up @@ -429,7 +429,7 @@ def _concatenate_custom(space: Space, items: Iterable, out: None) -> tuple[Any,

@singledispatch
def create_empty_array(
space: Space, n: int = 1, fn: callable = np.zeros
space: Space, n: int = 1, fn: Callable = np.zeros
) -> tuple[Any, ...] | dict[str, Any] | np.ndarray:
"""Create an empty (possibly nested and normally numpy-based) array, used in conjunction with ``concatenate(..., out=array)``.
Expand Down
3 changes: 2 additions & 1 deletion gymnasium/wrappers/rendering.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,9 @@
from __future__ import annotations

import os
from collections.abc import Callable
from copy import deepcopy
from typing import Any, Callable, SupportsFloat
from typing import Any, SupportsFloat

import numpy as np

Expand Down
2 changes: 1 addition & 1 deletion gymnasium/wrappers/transform_action.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@

from __future__ import annotations

from typing import Callable
from collections.abc import Callable

import numpy as np

Expand Down
Loading

0 comments on commit 1e326d5

Please sign in to comment.