I ran into some failures while writing unit tests, all of which stemmed from using scalars as inputs.
I'm opening this issue to document the only workaround I have for now.
Why does JAX always use arrays?
JAX operates on arrays, so scalars are represented as arrays too. Calling `jnp.array(1.0)` does not produce a raw scalar such as a Python float; it produces a 0-dim array with an empty shape `()` that holds a single element. This design choice keeps JAX's API consistent and efficient for operations on both scalars and larger arrays.
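The difference is easy to check by inspecting shapes directly. The snippet below is a minimal sketch; the variable names are illustrative and do not come from the test suite.

```python
import jax.numpy as jnp

# 0-dim array: what JAX produces for a "scalar" value.
scalar_like = jnp.array(1.0)
# 1-dim array that happens to hold a single element.
one_element = jnp.array((1.0,))

print(scalar_like.shape, scalar_like.ndim)  # shape and rank of the 0-dim array
print(one_element.shape, one_element.ndim)  # shape and rank of the 1-dim array
```

Both arrays hold exactly one element; only their rank differs, which is the distinction the rest of this issue depends on.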
Example of what works (using 1-dim array)
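```python
import jax.numpy as jnp
import pytest
from jax.typing import DTypeLike

# `supported_dtypes`, `scalar`, and `run_op_test` are assumed to come from
# the project's own test utilities.
@pytest.mark.parametrize("dtype", supported_dtypes)
def test_scalar_dtype(dtype: DTypeLike):
    def add(x: scalar) -> scalar:
        return x

    in0 = jnp.array((1,), dtype)  # Dummy 1-dim array used as input.
    run_op_test(add, [in0])
```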
The line `in0 = jnp.array((1,), dtype)` creates a 1-dim array with shape `(1,)`.
When a 0-dim input such as `jnp.array(1, dtype)` is used instead, the test fails with an error that compares two layouts: the left side of `!=` has `major_to_minor=()` and the right side has `major_to_minor=(0,)`. This suggests that XLA is internally trying to treat the scalar as a 1-dim array.
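One way to read the two tuples, assuming `major_to_minor` lists an array's dimension indices from slowest-varying to fastest-varying, is that it carries one entry per dimension. The sketch below computes the default row-major ordering for a shape using `default_major_to_minor`, a hypothetical helper written for illustration (it is not a JAX or XLA function).

```python
import jax.numpy as jnp


def default_major_to_minor(shape: tuple[int, ...]) -> tuple[int, ...]:
    """Hypothetical helper: row-major dimension order, one index per dimension."""
    return tuple(range(len(shape)))


scalar_like = jnp.array(1, jnp.float32)     # 0-dim array, shape ()
one_element = jnp.array((1,), jnp.float32)  # 1-dim array, shape (1,)

print(default_major_to_minor(scalar_like.shape))  # prints ()
print(default_major_to_minor(one_element.shape))  # prints (0,)
```

Under that assumption, the two sides of the failing comparison correspond to a rank-0 layout and a rank-1 layout, which matches the observation that the scalar is being handled as a 1-dim array somewhere along the way.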
ChatGPT's explanation of the problem
The issue you're encountering is related to the difference in how JAX treats scalars and arrays with different shapes and the internal layout of the data when running operations on devices like GPUs/TPUs. Let me explain the root cause and how to fix it.
The Root Cause:
Creating `in0 = jnp.array(1, dtype)`:
This creates a scalar array (0-dimensional) of the specified dtype. In JAX, a scalar created this way is treated as an array with shape () (i.e., a 0-dimensional array). The error you're seeing comes from the fact that JAX expects the layout of a 0-dimensional array to be different from a 1-dimensional array.
In particular, jnp.array(1, dtype) creates a scalar-like array that doesn't have a clear layout defined for device operations, as it isn't shaped as (1,) or any higher-dimensional array. The XLA (Accelerated Linear Algebra) compiler that JAX uses is trying to optimize the layout of the array on the device, and it encounters a mismatch when it tries to treat a scalar array as if it were a 1-dimensional array with shape (1,).
Creating `in0 = jnp.array((1,), dtype)`:
This explicitly creates a 1-dimensional array of shape (1,), even if it only contains one element. This means JAX knows that it's an array with one element, and XLA is able to handle the layout and optimization more gracefully, since it matches the expected 1D shape. This avoids the error you're seeing.
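The workaround from this issue can be wrapped in a small helper so that tests never pass a 0-dim array by accident. This is only a sketch: `make_test_input` is a hypothetical name, not part of the existing test infrastructure, and the shapes shown are what JAX reports on its default backend.

```python
import jax.numpy as jnp
from jax.typing import DTypeLike


def make_test_input(value, dtype: DTypeLike):
    """Hypothetical helper: wrap a single value in a 1-dim array of shape (1,)."""
    return jnp.array((value,), dtype)


in0 = make_test_input(1, jnp.float32)
zero_dim = jnp.array(1, jnp.float32)  # the 0-dim form that triggered the failure

print(in0.shape, in0.dtype)
print(zero_dim.shape, jnp.float32(1).shape)  # both constructions are 0-dim
```

Routing input creation through one function keeps the workaround in a single place, so it is easy to remove once 0-dim inputs are handled correctly.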
Example of what fails (using 0-dim array, aka scalar)
The line `in0 = jnp.array(1, dtype)` (which is the same as calling `dtype(1)`) creates a 0-dim array with an empty shape `()`. Running the test with this input causes the failure described in this issue.