diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index bfbfb30e..6c43ed90 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -11,8 +11,6 @@ repos: rev: v1.1.379 hooks: - id: pyright - # must match the Python version used in CI - language_version: python3.11 additional_dependencies: [ beartype, @@ -21,7 +19,5 @@ repos: jaxtyping, optax, pytest, - tensorflow, - tf2onnx, typing_extensions, ] diff --git a/equinox/_enum.py b/equinox/_enum.py index 1716a60c..b18adbdc 100644 --- a/equinox/_enum.py +++ b/equinox/_enum.py @@ -140,7 +140,7 @@ def __instancecheck__(cls, value): class EnumerationItem(Module): - _value: Int[Union[Array, np.ndarray], ""] + _value: Int[Union[Array, np.ndarray[Any, np.dtype[np.signedinteger]]], ""] # Should have annotation `"type[Enumeration]"`, but this fails due to beartype bug # #289. _enumeration: Any = field(static=True) diff --git a/equinox/_filters.py b/equinox/_filters.py index 13ee1fb5..6a219a56 100644 --- a/equinox/_filters.py +++ b/equinox/_filters.py @@ -37,7 +37,7 @@ def is_inexact_array(element: Any) -> bool: array. """ if isinstance(element, (np.ndarray, np.generic)): - return np.issubdtype(element.dtype, np.inexact) + return bool(np.issubdtype(element.dtype, np.inexact)) elif isinstance(element, jax.Array): return jnp.issubdtype(element.dtype, jnp.inexact) else: @@ -51,7 +51,7 @@ def is_inexact_array_like(element: Any) -> bool: if hasattr(element, "__jax_array__"): element = element.__jax_array__() if isinstance(element, (np.ndarray, np.generic)): - return np.issubdtype(element.dtype, np.inexact) + return bool(np.issubdtype(element.dtype, np.inexact)) elif isinstance(element, jax.Array): return jnp.issubdtype(element.dtype, jnp.inexact) else: diff --git a/equinox/internal/_onnx.py b/equinox/internal/_onnx.py index ae31263d..7c7d76ca 100644 --- a/equinox/internal/_onnx.py +++ b/equinox/internal/_onnx.py @@ -24,8 +24,8 @@ def f(x, y): ``` """ import jax.experimental.jax2tf as jax2tf - import tensorflow as tf - import tf2onnx + import tensorflow as tf # pyright: ignore[reportMissingImports] + import tf2onnx # pyright: ignore[reportMissingImports] def _to_onnx(*args): finalised_fn = finalise_fn(fn)