A performant JAX implementation of Bootstrapped Model Predictive Control (BMPC).
To install the dependencies for this project (tested on Ubuntu 22.04), run
pip install -U numpy tqdm "flax[all]" optax jaxtyping einops "gymnasium[mujoco]" hydra-core tensorboard orbax-checkpoint dm_control tensorflow tensorflow-probability tf-keras
pip install -U "jax[cuda12]"
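After installing "jax[cuda12]", it can help to confirm that JAX actually picked up the GPU before starting a long training run. The snippet below is a minimal sketch using only the public jax.devices() and jax.default_backend() calls; it assumes an NVIDIA driver compatible with CUDA 12 is present, and otherwise JAX silently falls back to the CPU.

```python
import jax

# List every device JAX can see. With a working CUDA 12 setup this
# should include one or more CUDA devices.
devices = jax.devices()

# The platform of the default backend: "gpu" if the CUDA plugin loaded,
# "cpu" if JAX fell back to the host processor.
backend = jax.default_backend()
print(f"backend={backend}, devices={devices}")

if backend != "gpu":
    print("Warning: JAX is running on CPU; check the CUDA 12 install and driver.")
```

Training still runs on the CPU fallback, but it will be far slower.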
Install the package from the base directory with
pip install -e .
Then, edit config.yaml and run train.py from the main project directory. Some examples:
# gymnasium
python train.py env.backend=gymnasium env.env_id=HalfCheetah-v4
# dm_control (DeepMind Control Suite)
python train.py env.backend=dmc env.env_id=cheetah-run
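The examples above select the environment through two overrides, env.backend and env.env_id, and the dm_control id is written as domain-task (cheetah-run). The README does not show how train.py turns these into an environment, so the following is only a hypothetical sketch of such a dispatcher: make_env and parse_dmc_id are illustrative names, not this repository's API, and the rule that the first "-" separates domain from task is an assumption. gymnasium.make and dm_control's suite.load are real library calls; note that a dm_control environment does not follow the Gymnasium step/reset API, so real code would typically wrap it.

```python
def parse_dmc_id(env_id: str) -> tuple[str, str]:
    """Split a dm_control id such as 'cheetah-run' into (domain, task).

    Hypothetical convention: the first '-' separates the domain from the
    task, so domains with underscores (e.g. 'ball_in_cup') still parse.
    """
    domain, sep, task = env_id.partition("-")
    if not sep or not domain or not task:
        raise ValueError(f"expected 'domain-task', got {env_id!r}")
    return domain, task


def make_env(backend: str, env_id: str):
    """Hypothetical dispatcher mirroring the env.backend / env.env_id overrides."""
    if backend == "gymnasium":
        import gymnasium as gym  # imported lazily so only the chosen backend is required

        return gym.make(env_id)
    if backend == "dmc":
        from dm_control import suite

        domain, task = parse_dmc_id(env_id)
        return suite.load(domain_name=domain, task_name=task)
    raise ValueError(f"unknown backend: {backend!r}")
```

For example, make_env("dmc", "cheetah-run") would load the cheetah domain with the run task.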
Special thanks to @wertyuilife2 for their contributions to this repository!