diff --git a/gymnasium/spaces/multi_discrete.py b/gymnasium/spaces/multi_discrete.py index 2e6592053..86ba399c8 100644 --- a/gymnasium/spaces/multi_discrete.py +++ b/gymnasium/spaces/multi_discrete.py @@ -59,6 +59,19 @@ def __init__( seed: Optionally, you can use this argument to seed the RNG that is used to sample from the space. start: Optionally, the starting value the element of each class will take (defaults to 0). """ + # determine dtype + if dtype is None: + raise ValueError( + "MultiDiscrete dtype must be explicitly provided, cannot be None." + ) + self.dtype = np.dtype(dtype) + + # * check that dtype is an accepted dtype + if not (np.issubdtype(self.dtype, np.integer)): + raise ValueError( + f"Invalid MultiDiscrete dtype ({self.dtype}), must be an integer dtype" + ) + self.nvec = np.array(nvec, dtype=dtype, copy=True) if start is not None: self.start = np.array(start, dtype=dtype, copy=True) @@ -70,7 +83,7 @@ def __init__( ), "start and nvec (counts) should have the same shape" assert (self.nvec > 0).all(), "nvec (counts) have to be positive" - super().__init__(self.nvec.shape, dtype, seed) + super().__init__(self.nvec.shape, self.dtype, seed) @property def shape(self) -> tuple[int, ...]: