
Support for scalars #163

Open
kmitrovicTT opened this issue Jan 13, 2025 · 1 comment

kmitrovicTT (Contributor) commented Jan 13, 2025

Summary

I have encountered some failures while writing unit tests that stemmed from using scalars.

Opening this issue to document the only fix I have for now, along with an explanation.

Why does JAX always use arrays?

In JAX, scalars are typically created as arrays, since JAX operates on arrays (even single-element ones). When you use jnp.array(1.0), you're creating a 0-dim JAX array rather than a raw Python scalar (as you might get in other frameworks like NumPy or pure Python). This design choice keeps JAX's API consistent and efficient for operations on both scalars and larger arrays.
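For illustration, here is a minimal snippet (plain JAX, independent of this repo's test harness) contrasting the two kinds of arrays:

import jax.numpy as jnp

x0 = jnp.array(1.0)     # 0-dim array: shape (), ndim 0
x1 = jnp.array((1.0,))  # 1-dim array: shape (1,), ndim 1

print(x0.shape, x0.ndim)  # -> () 0
print(x1.shape, x1.ndim)  # -> (1,) 1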

Example of what works (using 1-dim array)

import pytest
import jax.numpy as jnp
from jax.typing import DTypeLike

# `scalar`, `supported_dtypes` and `run_op_test` come from this repo's
# test infrastructure.

@pytest.mark.parametrize("dtype", supported_dtypes)
def test_scalar_dtype(dtype: DTypeLike):
    def add(x: scalar) -> scalar:
        return x

    in0 = jnp.array((1,), dtype)  # Dummy 1-dim array used as input.
    run_op_test(add, [in0])

The line in0 = jnp.array((1,), dtype) creates a 1-dim array with shape (1,).

For dtype = jnp.float32 this produces

module @jit_add attributes {
  func.func public @main(%arg0: tensor<1xf32>) -> (tensor<1xf32>) {
    return %arg0 : tensor<1xf32> 
  } 
}

Example of what fails (using 0-dim array, aka scalar)

# Same imports and repo test helpers as in the previous example.
@pytest.mark.parametrize("dtype", supported_dtypes)
def test_scalar_dtype(dtype: DTypeLike):
    def add(x: scalar) -> scalar:
        return x

    in0 = jnp.array(1, dtype)  # Dummy 0-dim array (scalar) used as input.
    run_op_test(add, [in0])

The line in0 = jnp.array(1, dtype) (which is the same as calling dtype(1)) creates a 0-dim array with an empty shape ().
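A quick sanity check of that equivalence, in plain JAX:

import jax.numpy as jnp

a = jnp.array(1, jnp.float32)  # explicit dtype argument
b = jnp.float32(1)             # calling the dtype directly

assert a.shape == b.shape == ()           # both are 0-dim
assert a.dtype == b.dtype == jnp.float32  # same dtype either way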

For dtype = jnp.float32 this produces

module @jit_add attributes {
  func.func public @main(%arg0: tensor<f32>) -> (tensor<f32>) {
    return %arg0 : tensor<f32> 
  } 
} 

Running the above Python code causes:

Unexpected XLA layout override: (XLA) DeviceLocalLayout(major_to_minor=(), _tiling=(), _sub_byte_element_size_in_bits=0) != DeviceLocalLayout(major_to_minor=(0,), _tiling=(), _sub_byte_element_size_in_bits=0)

On the left side of != we have major_to_minor=(), and on the right we have major_to_minor=(0,). This seems to indicate that XLA is internally trying to treat the scalar as a 1-dim array.

ChatGPT's explanation of the problem

The issue you're encountering is related to the difference in how JAX treats scalars and arrays with different shapes and the internal layout of the data when running operations on devices like GPUs/TPUs. Let me explain the root cause and how to fix it.

The Root Cause:
Creating in0 = jnp.array(1, dtype):
This creates a scalar array (0-dimensional) of the specified dtype. In JAX, a scalar created this way is treated as an array with shape () (i.e., a 0-dimensional array). The error you're seeing comes from the fact that JAX expects the layout of a 0-dimensional array to be different from a 1-dimensional array.

In particular, jnp.array(1, dtype) creates a scalar-like array that doesn't have a clear layout defined for device operations, as it isn't shaped as (1,) or any higher-dimensional array. The XLA (Accelerated Linear Algebra) compiler that JAX uses is trying to optimize the layout of the array on the device, and it encounters a mismatch when it tries to treat a scalar array as if it were a 1-dimensional array with shape (1,).

Creating in0 = jnp.array((1,), dtype):
This explicitly creates a 1-dimensional array of shape (1,), even if it only contains one element. This means JAX knows that it's an array with one element, and XLA is able to handle the layout and optimization more gracefully, since it matches the expected 1D shape. This avoids the error you're seeing.
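Based on the above, a possible workaround for now (just a sketch, not necessarily the final fix) is to promote 0-dim inputs to shape (1,) before handing them to the test harness, e.g. with jnp.atleast_1d:

import jax.numpy as jnp

def promote_scalar(x):
    # Hypothetical helper: jnp.atleast_1d turns a 0-dim array of shape ()
    # into a 1-dim array of shape (1,); inputs with ndim >= 1 pass through.
    return jnp.atleast_1d(x)

in0 = promote_scalar(jnp.array(1, jnp.float32))
assert in0.shape == (1,)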

@kmitrovicTT kmitrovicTT self-assigned this Jan 13, 2025

kmitrovicTT (Contributor, Author) commented:

Funny thing is that running this solely on CPU works, and even the result is a scalar (aka a 0-dim array).
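For reference, one way to reproduce that CPU-only run (assuming a CPU device is available) is to pin JAX's default device:

import jax
import jax.numpy as jnp

with jax.default_device(jax.devices("cpu")[0]):
    out = jax.jit(lambda x: x)(jnp.array(1, jnp.float32))
    print(out.shape)  # -> (): the 0-dim result round-trips fine on CPU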
