PureJaxRL Resources

Last year, I released PureJaxRL, a simple repository that implements RL algorithms entirely end-to-end in JAX, enabling speedups of up to 4000x in RL training. PureJaxRL was in turn inspired by multiple projects, including CleanRL and Gymnax. Since its release, a large number of related or inspired projects have come out, expanding its use cases well beyond standard single-agent RL. This curated list collects those projects alongside other relevant implementations of algorithms, environments, tools, and tutorials.

To understand more about the benefits of PureJaxRL, I recommend viewing the original blog post or tweet thread.

The PureJaxRL repository can be found here:

https://github.com/luchris429/purejaxrl/.
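
The core idea behind the speedups is that the environment, the agent, and the training loop are all pure JAX functions, so the entire training run can be jit-compiled and vmapped over random seeds to train many agents in parallel on a single accelerator. Below is a minimal sketch of that pattern, not PureJaxRL's actual code: the `make_train` factory, config keys, and placeholder update are illustrative stand-ins.

```python
# Minimal sketch (not PureJaxRL's actual code) of the end-to-end pattern:
# the whole training loop is a pure JAX function, so it can be jit-compiled
# and vmapped over seeds to train many agents in parallel on one device.
import jax
import jax.numpy as jnp

def make_train(config):  # hypothetical factory, mirroring the common pattern
    def train(rng):
        # Placeholder "agent": a parameter vector updated for num_updates steps.
        params = jnp.zeros(config["num_params"])

        def update_step(carry, _):
            params, rng = carry
            rng, _rng = jax.random.split(rng)
            # A real implementation would roll out a JAX-native environment
            # (e.g. gymnax/brax) and apply a PPO or Q-learning update here.
            params = params + 0.01 * jax.random.normal(_rng, params.shape)
            return (params, rng), jnp.mean(params)

        (params, rng), metrics = jax.lax.scan(
            update_step, (params, rng), None, length=config["num_updates"]
        )
        return {"params": params, "metrics": metrics}

    return train

config = {"num_params": 8, "num_updates": 100}
rngs = jax.random.split(jax.random.PRNGKey(0), 16)  # 16 independent seeds
train_vjit = jax.jit(jax.vmap(make_train(config)))  # compile the whole loop
out = train_vjit(rngs)                               # 16 training runs in parallel
```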

The format of the list is from awesome and awesome-jax. While this list is curated, it is certainly not complete. If you have a repository you would like to add, please contribute!

If you find this resource useful, please star the repo! It helps establish and grow the end-to-end JAX RL community.

Contents

  • Algorithms

  • Environments

  • Relevant Tools and Components

  • Tutorials and Blog Posts

Algorithms

End-to-End JAX RL Implementations

  • purejaxrl - Classic and simple end-to-end RL training in pure JAX.

  • rejax - Modular and importable end-to-end JAX RL training.

  • Stoix - End-to-end JAX RL training with advanced logging, configs, and more.

  • purejaxql - Simple single-file end-to-end JAX baselines for Q-Learning.

  • jym - Educational and beginner-friendly end-to-end JAX RL training.

JAX RL (But Not End-to-End) Repos

  • cleanrl - Clean implementations of RL Algorithms (in both PyTorch and JAX!).

  • jaxrl - JAX implementation of algorithms for Deep Reinforcement Learning with continuous action spaces.

  • rlbase - Single-file JAX implementations of Deep RL algorithms.

Multi-Agent RL

  • JaxMARL - Multi-Agent RL Algorithms and Environments in pure JAX.

  • Mava - Multi-Agent RL Algorithms in pure JAX (previously TensorFlow-based).

  • pax - Scalable Opponent Shaping Algorithms in pure JAX.

Offline RL

  • JAX-CORL - Single-file implementations of offline RL algorithms in JAX.

Inverse-RL

  • jaxirl - Pure JAX for Inverse Reinforcement Learning.

Unsupervised Environment Design

  • minimax - Canonical implementations of UED algorithms in pure JAX, including SSM-based acceleration.

  • jaxued - Single-file implementations of UED algorithms in pure JAX.

Quality-Diversity

  • QDax - Quality-Diversity algorithms in pure JAX.

Partially-Observed RL

  • popjaxrl - Partially-observed RL environments (POPGym) and architectures (incl. SSMs) in pure JAX.

Meta-Learning RL Objectives

Environments

  • gymnax - Classic RL environments in JAX (see the API sketch after this list).

  • brax - Continuous control environments in JAX.

  • JaxMARL - Multi-agent algorithms and environments in pure JAX.

  • jumanji - Suite of unique RL environments in JAX.

  • pgx - Suite of popular board games in JAX.

  • popjaxrl - Partially-observed RL environments (POPGym) in JAX.

  • waymax - Self-driving car simulator in JAX.

  • Craftax - A challenging Crafter-like and NetHack-inspired benchmark in JAX.

  • xland-minigrid - A large-scale meta-RL environment in JAX.

  • navix - Classic minigrid environments in JAX.

  • autoverse - A fast, evolvable description language for reinforcement learning environments.

  • qdx - Quantum Error Correction with JAX.

  • matrax - Matrix games in JAX.

  • AlphaTrade - Limit Order Book (LOB) simulation in JAX.
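
Most of the environments above expose a functional, stateless reset/step API so that entire rollouts can live inside `jit`, `vmap`, and `scan`. A minimal sketch in the style of gymnax follows; exact signatures may differ slightly across versions and across the other libraries.

```python
import jax
import gymnax

rng = jax.random.PRNGKey(0)
rng, key_reset, key_act, key_step = jax.random.split(rng, 4)

# Instantiate the environment and its default parameters.
env, env_params = gymnax.make("CartPole-v1")

# Reset and step are pure functions: the environment state is passed in and
# returned explicitly, so whole rollouts can be jit-compiled and vmapped.
obs, state = env.reset(key_reset, env_params)
action = env.action_space(env_params).sample(key_act)
n_obs, n_state, reward, done, _ = env.step(key_step, state, action, env_params)
```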

Relevant Tools and Components

  • evosax - Evolution strategies in JAX.

  • evojax - Evolution strategies in JAX.

  • flashbax - Accelerated replay buffers in JAX.

  • dejax - Accelerated replay buffers in JAX.

  • rlax - RL components and building blocks in JAX.

  • mctx - Monte Carlo tree search in JAX.

  • distrax - Distributions in JAX.

  • optax - Gradient-based optimizers in JAX.

  • flax - Neural Networks in JAX (see the sketch after this list).
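
These components are designed to compose inside a single jit-compiled training step. As a rough sketch under assumed settings (a toy network size and a placeholder loss, not a real agent), a flax policy can be optimized with optax like this:

```python
import jax
import jax.numpy as jnp
import flax.linen as nn
import optax

class Policy(nn.Module):
    n_actions: int

    @nn.compact
    def __call__(self, x):
        x = nn.relu(nn.Dense(64)(x))
        return nn.Dense(self.n_actions)(x)

# Initialise parameters and an Adam optimizer state.
policy = Policy(n_actions=2)
params = policy.init(jax.random.PRNGKey(0), jnp.zeros((4,)))
tx = optax.adam(3e-4)
opt_state = tx.init(params)

def loss_fn(params, obs):
    logits = policy.apply(params, obs)
    # Placeholder loss; a real agent would use e.g. a policy-gradient or TD
    # loss (see rlax/distrax for the usual building blocks).
    return jnp.mean(logits ** 2)

# One gradient step, all in pure JAX so the full update can be jit-compiled.
grads = jax.grad(loss_fn)(params, jnp.ones((4,)))
updates, opt_state = tx.update(grads, opt_state)
params = optax.apply_updates(params, updates)
```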

Tutorials and Blog Posts