Description

Fails with:

ValueError: vmap got inconsistent sizes for array axes to be mapped:
* most axes (2 of them) had size 2, e.g. axis 0 of argument x1 of type float32[2];
* one axis had size 1: axis 0 of argument x2 of type float32[2]

This is wrong. `x2` is not the argument that has size 1 (that would be `a3`).

System info (python version, jaxlib version, accelerator, etc.)
jax: 0.4.33
jaxlib: 0.4.33
numpy: 1.26.4
python: 3.10.12 (main, Sep 11 2024, 15:47:36) [GCC 11.4.0]
jax.devices (1 total, 1 local): [CpuDevice(id=0)]
process_count: 1
platform: uname_result(system='Linux', node='5c2409a38508', release='6.1.85+', version='#1 SMP PREEMPT_DYNAMIC Thu Jun 27 21:05:47 UTC 2024', machine='x86_64')
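
The reproduction snippet is not preserved in this copy of the report. As a rough, hypothetical sketch of the kind of call that raises this error, assume a three-argument function `f(x1, x2, a3)` in which only `a3` has a mapped axis of size 1. The function body and the array values here are assumptions for illustration, not the original code, and this sketch does not claim to reproduce the mislabeled argument name, which may depend on how the arguments are passed and on the JAX version.

```python
import jax
import jax.numpy as jnp


def f(x1, x2, a3):
    # Hypothetical function; argument names mirror those in the error message.
    return x1 + x2 + a3


x1 = jnp.ones(2)  # mapped axis 0 has size 2
x2 = jnp.ones(2)  # mapped axis 0 has size 2
a3 = jnp.ones(1)  # assumed: the only argument whose mapped axis has size 1

try:
    # vmap requires every mapped axis to have the same size, so this raises.
    jax.vmap(f)(x1, x2, a3)
    message = None
except ValueError as err:
    message = str(err)

# Inspect which argument the message blames for the size-1 axis.
print(message)
```

For a correct message, the "one axis had size 1" line would name `a3` and report its type as `float32[1]`.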