This is working as intended, but I agree it's a bit of a strange corner case. The issue is that, across all its APIs, JAX does not implicitly convert list inputs to arrays, because when it did, this led to silent and hard-to-debug performance issues in practice. This is discussed in the JAX Sharp Bits under "Non-array inputs".
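A minimal sketch of the "no implicit list conversion" rule, assuming a recent JAX release: passing a list straight to a jax.numpy function such as jnp.sum is rejected with a TypeError, while converting once with jnp.asarray makes the conversion explicit. The exact error message may differ between JAX versions.

```python
import jax.numpy as jnp

values = [1, 2, 3]

# Passing a Python list directly to a jax.numpy function is rejected
# instead of being converted to an array behind the scenes.
try:
    jnp.sum(values)
    rejected = False
except TypeError as err:
    rejected = True
    print("TypeError:", err)

# Converting explicitly, once, keeps the cost of the conversion visible.
arr = jnp.asarray(values)
total = jnp.sum(arr)
print(int(total))  # 6
```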
Given this, we have to decide what == should do when it encounters a JAX array and a list. One option would be to raise an error, as jnp.equal does, but this causes problems because there are situations where the Python interpreter expects __eq__ not to fail. So instead JAX's __eq__ returns NotImplemented, which tells Python to try the other operand's method (in this case list.__eq__). Since list.__eq__ cannot handle a JAX array either, Python falls back to an identity comparison, and the result is False. Perhaps raising an error would be more useful to users, though; I'm not sure.
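The NotImplemented dispatch described above is plain Python operator protocol. Here is a minimal sketch using a hypothetical ToyArray class (made up for illustration, not JAX's actual implementation) whose __eq__, like the behavior described, declines operands it does not recognize:

```python
class ToyArray:
    """Hypothetical array stand-in; NOT JAX's real implementation."""

    # Hypothetical set of scalar types this toy class agrees to compare with.
    _accepted = (int, float, complex)

    def __init__(self, values):
        self.values = tuple(values)

    def __eq__(self, other):
        if isinstance(other, ToyArray):
            # Elementwise comparison (assumes equal lengths for simplicity).
            return tuple(a == b for a, b in zip(self.values, other.values))
        if isinstance(other, self._accepted):
            # Broadcast a scalar against every element.
            return tuple(a == other for a in self.values)
        # Unrecognized operand (e.g. a list): let Python try the other side.
        return NotImplemented


x = ToyArray([1, 2, 3])
print(x == 2)          # (False, True, False)
print(x == [1, 2, 3])  # False: both sides return NotImplemented, identity fallback
print([1, 2, 3] == x)  # False, for the same reason
print(x != [1, 2, 3])  # True: != falls back to "is not"
```

Because both sides decline, Python never raises here; the cost is that a comparison the user probably meant elementwise silently yields False.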
Anyway, I hope that makes it clear why JAX's behavior intentionally diverges from NumPy's in this case. If you want to do array operations in JAX, you need to use arrays, not sequences like lists or tuples.
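A minimal sketch of the workaround, assuming numpy and jax are installed: NumPy converts the list implicitly, while in JAX you convert it yourself with jnp.asarray before comparing, and use jnp.array_equal when you want a single True/False answer. The result of comparing a jax.Array directly against a list is deliberately not relied on here, since it may vary across JAX versions.

```python
import numpy as np
import jax.numpy as jnp

values = [1, 2, 3]

# NumPy converts the list implicitly and compares elementwise.
np_mask = np.array(values) == values

# In JAX, convert the list explicitly before comparing; comparing a
# jax.Array directly against a list does not give an elementwise result.
jax_mask = jnp.array(values) == jnp.asarray(values)

# For a single yes/no answer, reduce explicitly instead of relying on ==.
same = bool(jnp.array_equal(jnp.array(values), jnp.asarray(values)))

print(np_mask)   # [ True  True  True]
print(jax_mask)  # [ True  True  True]
print(same)      # True
```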
Description
The way JAX implements == when comparing lists with JAX arrays is inconsistent with how NumPy handles it. Small example:
System info (python version, jaxlib version, accelerator, etc.)