BatchJAX

Description

BatchJAX is a library that allows JAX's vmap to be applied over Python lists and objax.ModuleList containers.

Installation

    pip install batchjax

Example

See batchjax_example.ipynb.
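jax.vmap maps over a leading array axis, not over a Python list. The usual workaround is to stack a list of identically structured pytrees into a single pytree, vmap over it, and then split the result back into a list. The sketch below shows that idea in plain JAX only. It is not BatchJAX's API or necessarily how BatchJAX works internally. The helpers stack_list, unstack and vmap_over_list are hypothetical names used for illustration. The example also assumes every list element has the same tree structure and leaf shapes.

```python
import jax
import jax.numpy as jnp


def stack_list(items):
    # Stack a list of identically structured pytrees into one pytree whose
    # leaves gain a new leading axis of size len(items).
    return jax.tree_util.tree_map(lambda *leaves: jnp.stack(leaves), *items)


def unstack(tree, n):
    # Split the leading axis back into a Python list of n pytrees.
    return [jax.tree_util.tree_map(lambda leaf: leaf[i], tree) for i in range(n)]


def vmap_over_list(fn, items):
    # Hypothetical helper (not part of BatchJAX): apply fn to every element
    # of items in a single vmapped call, returning a Python list of results.
    stacked = stack_list(items)
    out = jax.vmap(fn)(stacked)
    return unstack(out, len(items))


# Two parameter sets with the same structure, e.g. two small linear models.
params = [
    {"w": jnp.array([1.0, 2.0]), "b": jnp.array(0.5)},
    {"w": jnp.array([3.0, 4.0]), "b": jnp.array(-1.0)},
]
x = jnp.array([1.0, 1.0])


def linear(p):
    # Evaluate one linear model on the shared input x.
    return jnp.dot(p["w"], x) + p["b"]


results = vmap_over_list(linear, params)
print([float(r) for r in results])  # [3.5, 6.0]
```

Stacking turns the list into one batched computation, which is what lets vmap vectorize it. The cost is an extra copy of the leaves and the requirement that all elements share one structure.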