Add PPO + Transformer-XL (vwxyzjn#459)
* initial commit of ppo trxl

* removed video capture

* Switched from grayscale to RGB

* RGB obs reconstruction

* print reconstruction loss instead of total loss

* Ensure that transformer memory length is not larger than max episode steps

* fixed enjoy.py in the case of Searing Spotlights

* added video capture support again after updating memory gym

* default hyperparameters

* remove unnecessary padding from TrXL memory, if applicable

* print SPS

* updated pyproject.toml because of memory-gym 1.0.2

* fixed comment

* refactored code + added a comment

* added an annealed entropy coefficient; the learning rate likewise anneals from its initial to its final value (see the annealing sketch after this commit message)

* aligned monitoring of losses and further metrics

* slight adjustments due to pre-commit; pre-commit still fails on unused imports, but these imports are necessary for the used environments to be registered

* heads share parameters again for reproduction purposes, set default hyperparameters for successful MMG training

* fixed entropy for multi-discrete action spaces

* advantage normalization is off per default, added hidden layer after TrXL

* set max_episode_steps for endless environments

* updated poetry.lock, defined max episode steps for endless environments, started docs, added benchmark shell script

* added Transformer-XL (PPO-TrXL) to the navigation bar, improved docs

* Added report to docs

* refactored "blocks" to "layers"

* Finalized enjoy.py, which can load pre-trained models from the Hugging Face Hub; updated docs

* pre-commit enjoy.py

* add #noqa to fix pre-commit

* fix pre-commit #noqa

* pre-commit fixed enjoy import order

* last pass of pre-commit

* fixed spelling

* remove macos-latest from .github/workflows/tests.yaml

* Added ppo_trxl to README.md, fixed enjoy and ppo_trxl for MiniGrid, added proper rendering to ProofofMemory-v0, updated docs for training and enjoying MiniGrid and ProofofMemory-v0

* pre-commit fixes

* Add requirements.txt, update poetry, torch defaults to CUDA, updated docs

* updated doc links, added memory gym requirements to README.md

---------

Co-authored-by: Horrible22232 <[email protected]>
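
The commit message above states that the entropy coefficient and the learning rate both anneal from an initial to a final value. As a rough illustration of that idea, here is a minimal sketch assuming simple linear interpolation; the helper name `anneal` and the example values are hypothetical and not taken from `ppo_trxl.py`:

```python
# Minimal sketch of linear annealing as described in the commit message.
# The helper name `anneal` and the example values are hypothetical.
def anneal(initial: float, final: float, update: int, num_updates: int) -> float:
    """Linearly interpolate from `initial` to `final` over `num_updates` updates."""
    frac = min(max(update / num_updates, 0.0), 1.0)
    return initial + frac * (final - initial)


if __name__ == "__main__":
    # Example: decay the learning rate from 2.5e-4 to 1e-5 and the entropy
    # coefficient from 0.01 to 0.0 over 100 policy updates.
    for update in (0, 25, 50, 75, 100):
        lr = anneal(2.5e-4, 1.0e-5, update, 100)
        ent_coef = anneal(0.01, 0.0, update, 100)
        print(f"update {update:3d}: lr={lr:.6f} ent_coef={ent_coef:.4f}")
```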
MarcoMeter and Horrible22232 authored Sep 18, 2024
1 parent 65789ba commit 9752b32
Showing 12 changed files with 2,896 additions and 3 deletions.
6 changes: 3 additions & 3 deletions .github/workflows/tests.yaml
@@ -12,7 +12,7 @@ jobs:
      matrix:
        python-version: ["3.8", "3.9", "3.10"]
        poetry-version: ["1.7"]
-       os: [ubuntu-22.04, macos-latest, windows-latest]
+       os: [ubuntu-22.04, windows-latest]
    runs-on: ${{ matrix.os }}
    steps:
      - uses: actions/checkout@v2
@@ -55,7 +55,7 @@ jobs:
      matrix:
        python-version: ["3.8", "3.9", "3.10"]
        poetry-version: ["1.7"]
-       os: [ubuntu-22.04, macos-latest, windows-latest]
+       os: [ubuntu-22.04, windows-latest]
    runs-on: ${{ matrix.os }}
    steps:
      - uses: actions/checkout@v2
@@ -91,7 +91,7 @@ jobs:
      matrix:
        python-version: ["3.8", "3.9", "3.10"]
        poetry-version: ["1.7"]
-       os: [ubuntu-22.04, macos-latest, windows-latest]
+       os: [ubuntu-22.04, windows-latest]
    runs-on: ${{ matrix.os }}
    steps:
      - uses: actions/checkout@v2
2 changes: 2 additions & 0 deletions README.md
@@ -85,6 +85,7 @@ pip install -r requirements/requirements-pettingzoo.txt
pip install -r requirements/requirements-jax.txt
pip install -r requirements/requirements-docs.txt
pip install -r requirements/requirements-cloud.txt
+pip install -r requirements/requirements-memory_gym.txt
```

To run training scripts in other games:
@@ -140,6 +141,7 @@ You may also use a prebuilt development environment hosted in Gitpod:
| | [`ppo_atari_multigpu.py`](https://github.com/vwxyzjn/cleanrl/blob/master/cleanrl/ppo_atari_multigpu.py), [docs](https://docs.cleanrl.dev/rl-algorithms/ppo/#ppo_atari_multigpupy)
| | [`ppo_pettingzoo_ma_atari.py`](https://github.com/vwxyzjn/cleanrl/blob/master/cleanrl/ppo_pettingzoo_ma_atari.py), [docs](https://docs.cleanrl.dev/rl-algorithms/ppo/#ppo_pettingzoo_ma_ataripy)
| | [`ppo_continuous_action_isaacgym.py`](https://github.com/vwxyzjn/cleanrl/blob/master/cleanrl/ppo_continuous_action_isaacgym/ppo_continuous_action_isaacgym.py), [docs](https://docs.cleanrl.dev/rl-algorithms/ppo/#ppo_continuous_action_isaacgympy)
+| | [`ppo_trxl.py`](https://github.com/vwxyzjn/cleanrl/blob/master/cleanrl/ppo_trxl/ppo_trxl.py), [docs](https://docs.cleanrl.dev/rl-algorithms/ppo_trxl/)
|[Deep Q-Learning (DQN)](https://web.stanford.edu/class/psych209/Readings/MnihEtAlHassibis15NatureControlDeepRL.pdf) | [`dqn.py`](https://github.com/vwxyzjn/cleanrl/blob/master/cleanrl/dqn.py), [docs](https://docs.cleanrl.dev/rl-algorithms/dqn/#dqnpy) |
| | [`dqn_atari.py`](https://github.com/vwxyzjn/cleanrl/blob/master/cleanrl/dqn_atari.py), [docs](https://docs.cleanrl.dev/rl-algorithms/dqn/#dqn_ataripy) |
| | [`dqn_jax.py`](https://github.com/vwxyzjn/cleanrl/blob/master/cleanrl/dqn_jax.py), [docs](https://docs.cleanrl.dev/rl-algorithms/dqn/#dqn_jaxpy) |
52 changes: 52 additions & 0 deletions benchmark/ppo_trxl.sh
@@ -0,0 +1,52 @@
# export WANDB_ENTITY=openrlbenchmark

cd cleanrl/ppo_trxl
poetry install
OMP_NUM_THREADS=4 poetry run python -m cleanrl_utils.benchmark \
    --env-ids MortarMayhem-Grid-v0 \
    --command "python ./cleanrl/ppo_trxl/ppo_trxl.py --track --norm_adv --trxl_memory_length 119 --total_timesteps 100000000" \
    --num-seeds 3 \
    --workers 32 \
    --slurm-template-path benchmark/cleanrl_1gpu.slurm_template

OMP_NUM_THREADS=4 poetry run python -m cleanrl_utils.benchmark \
    --env-ids MortarMayhem-v0 \
    --command "python ./cleanrl/ppo_trxl/ppo_trxl.py --track --reconstruction_coef 0.1 --trxl_memory_length 275" \
    --num-seeds 3 \
    --workers 32 \
    --slurm-template-path benchmark/cleanrl_1gpu.slurm_template

OMP_NUM_THREADS=4 poetry run python -m cleanrl_utils.benchmark \
    --env-ids MysteryPath-Grid-v0 \
    --command "python ./cleanrl/ppo_trxl/ppo_trxl.py --track --trxl_memory_length 96 --total_timesteps 100000000" \
    --num-seeds 3 \
    --workers 32 \
    --slurm-template-path benchmark/cleanrl_1gpu.slurm_template

OMP_NUM_THREADS=4 poetry run python -m cleanrl_utils.benchmark \
    --env-ids MysteryPath-v0 \
    --command "python ./cleanrl/ppo_trxl/ppo_trxl.py --track --trxl_memory_length 256" \
    --num-seeds 3 \
    --workers 32 \
    --slurm-template-path benchmark/cleanrl_1gpu.slurm_template

OMP_NUM_THREADS=4 poetry run python -m cleanrl_utils.benchmark \
    --env-ids SearingSpotlights-v0 \
    --command "python ./cleanrl/ppo_trxl/ppo_trxl.py --track --reconstruction_coef 0.1 --trxl_memory_length 256" \
    --num-seeds 3 \
    --workers 32 \
    --slurm-template-path benchmark/cleanrl_1gpu.slurm_template

OMP_NUM_THREADS=4 poetry run python -m cleanrl_utils.benchmark \
    --env-ids Endless-SearingSpotlights-v0 \
    --command "python ./cleanrl/ppo_trxl/ppo_trxl.py --track --reconstruction_coef 0.1 --trxl_memory_length 256 --total_timesteps 350000000" \
    --num-seeds 3 \
    --workers 32 \
    --slurm-template-path benchmark/cleanrl_1gpu.slurm_template

OMP_NUM_THREADS=4 poetry run python -m cleanrl_utils.benchmark \
    --env-ids Endless-MortarMayhem-v0 Endless-MysteryPath-v0 \
    --command "python ./cleanrl/ppo_trxl/ppo_trxl.py --track --trxl_memory_length 256 --total_timesteps 350000000" \
    --num-seeds 3 \
    --workers 32 \
    --slurm-template-path benchmark/cleanrl_1gpu.slurm_template
91 changes: 91 additions & 0 deletions cleanrl/ppo_trxl/enjoy.py
@@ -0,0 +1,91 @@
from dataclasses import dataclass

import gymnasium as gym
import torch
import tyro
from ppo_trxl import Agent, make_env


@dataclass
class Args:
    hub: bool = False
    """whether to load the model from the huggingface hub or from the local disk"""
    name: str = "Endless-MortarMayhem-v0_12.nn"
    """path to the model file"""


if __name__ == "__main__":
    # Parse command line arguments and retrieve model path
    cli_args = tyro.cli(Args)
    if cli_args.hub:
        try:
            from huggingface_hub import hf_hub_download

            path = hf_hub_download(repo_id="LilHairdy/cleanrl_memory_gym", filename=cli_args.name)
        except Exception as e:
            raise RuntimeError(
                "Cannot load model from the huggingface hub. Please install the huggingface_hub pypi package and verify the model name. You can also download the model from the hub manually and load it from disk."
            ) from e
    else:
        path = cli_args.name

    # Load the pre-trained model and the original args used to train it
    checkpoint = torch.load(path)
    args = checkpoint["args"]
    args = type("Args", (), args)

    # Init environment and reset
    env = make_env(args.env_id, 0, False, "", "human")()
    obs, _ = env.reset()
    env.render()

    # Determine maximum episode steps
    max_episode_steps = env.spec.max_episode_steps
    if not max_episode_steps:
        max_episode_steps = env.max_episode_steps
    if max_episode_steps <= 0:
        max_episode_steps = 1024  # Memory Gym envs have max_episode_steps set to -1
        # Max episode steps impacts the positional encoding, so make sure to set this accordingly

    # Setup agent and load its model parameters
    action_space_shape = (
        (env.action_space.n,) if isinstance(env.action_space, gym.spaces.Discrete) else tuple(env.action_space.nvec)
    )
    agent = Agent(args, env.observation_space, action_space_shape, max_episode_steps)
    agent.load_state_dict(checkpoint["model_weights"])

    # Setup Transformer-XL memory, mask and indices
    memory = torch.zeros((1, max_episode_steps, args.trxl_num_layers, args.trxl_dim), dtype=torch.float32)
    memory_mask = torch.tril(torch.ones((args.trxl_memory_length, args.trxl_memory_length)), diagonal=-1)
    repetitions = torch.repeat_interleave(
        torch.arange(0, args.trxl_memory_length).unsqueeze(0), args.trxl_memory_length - 1, dim=0
    ).long()
    memory_indices = torch.stack(
        [torch.arange(i, i + args.trxl_memory_length) for i in range(max_episode_steps - args.trxl_memory_length + 1)]
    ).long()
    memory_indices = torch.cat((repetitions, memory_indices))

    # Run episode
    done = False
    t = 0
    while not done:
        # Prepare observation and memory
        obs = torch.Tensor(obs).unsqueeze(0)
        memory_window = memory[0, memory_indices[t].unsqueeze(0)]
        t_ = max(0, min(t, args.trxl_memory_length - 1))
        mask = memory_mask[t_].unsqueeze(0)
        indices = memory_indices[t].unsqueeze(0)
        # Forward agent
        action, _, _, _, new_memory = agent.get_action_and_value(obs, memory_window, mask, indices)
        memory[:, t] = new_memory
        # Step
        obs, reward, termination, truncation, info = env.step(action.cpu().squeeze().numpy())
        env.render()
        done = termination or truncation
        t += 1

    if "r" in info["episode"].keys():
        print(f"Episode return: {info['episode']['r'][0]}, Episode length: {info['episode']['l'][0]}")
    else:
        print(f"Episode return: {info['reward']}, Episode length: {info['length']}")
    env.close()
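
To make the sliding Transformer-XL memory window in `enjoy.py` easier to follow, here is a minimal, self-contained sketch of the `memory_indices` construction with toy sizes (`max_episode_steps = 8` and `memory_length = 4` are illustrative values, not defaults from this commit). Each row `memory_indices[t]` lists which past steps' cached activations form the memory window at step `t`; the lower-triangular `memory_mask` in the script then hides the slots that have not been written yet.

```python
# Toy-sized sketch of the memory window indexing used in enjoy.py above.
# The sizes are hypothetical and chosen only for illustration.
import torch

max_episode_steps = 8
memory_length = 4

# The first (memory_length - 1) steps all reuse the initial window [0..memory_length-1];
# the mask in enjoy.py hides the entries that lie in the future.
repetitions = torch.repeat_interleave(
    torch.arange(0, memory_length).unsqueeze(0), memory_length - 1, dim=0
).long()
# From then on, the window slides forward by one step per timestep.
memory_indices = torch.stack(
    [torch.arange(i, i + memory_length) for i in range(max_episode_steps - memory_length + 1)]
).long()
memory_indices = torch.cat((repetitions, memory_indices))
print(memory_indices)
# tensor([[0, 1, 2, 3],
#         [0, 1, 2, 3],
#         [0, 1, 2, 3],
#         [0, 1, 2, 3],
#         [1, 2, 3, 4],
#         [2, 3, 4, 5],
#         [3, 4, 5, 6],
#         [4, 5, 6, 7]])
```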