
ShaneFlandermeyer/bmpc-jax


bmpc-jax

A performant JAX implementation of Bootstrapped Model Predictive Control (BMPC).

Dependencies

To install the dependencies for this project (tested on Ubuntu 22.04), run

pip install -U numpy tqdm "flax[all]" optax jaxtyping einops "gymnasium[mujoco]" hydra-core tensorboard orbax-checkpoint dm_control tensorflow tensorflow-probability tf-keras
pip install -U "jax[cuda12]"
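Before training, it can help to confirm that the CUDA build of JAX is actually in use. The following is a minimal sketch that relies only on public JAX calls (`jax.default_backend`, `jax.devices`, `jax.jit`) and assumes nothing about this repository's code.

```python
import jax
import jax.numpy as jnp

# Platform JAX selected by default: "gpu" is expected after installing
# jax[cuda12] on a machine with a working NVIDIA driver, "cpu" otherwise.
backend = jax.default_backend()
devices = jax.devices()
print(f"backend={backend}, devices={devices}")


# A tiny jit-compiled function to confirm that the backend can run code.
@jax.jit
def square_sum(x):
    return jnp.sum(x ** 2)


# 0^2 + 1^2 + 2^2 + 3^2 = 14
print(square_sum(jnp.arange(4.0)))  # 14.0
```

If `backend` reports `cpu` on a GPU machine, the CUDA-enabled JAX wheel was most likely not picked up.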

Installation

Install the package from the base directory with

pip install -e .

Usage

Edit config.yaml as needed, then run train.py from the main project directory. Settings can also be overridden on the command line. Some examples:

# gymnasium
python train.py env.backend=gymnasium env.env_id=HalfCheetah-v4
# dm_control (DeepMind Control Suite)
python train.py env.backend=dmc env.env_id=cheetah-run

Acknowledgements

Special thanks to @wertyuilife2 for their contributions to this repository!
