forked from vwxyzjn/cleanrl
Add PPO + Transformer-XL (vwxyzjn#459)
* initial commit of ppo trxl
* removed video capture
* Switched from grayscale to RGB
* RGB obs reconstruction
* print reconstruction loss instead of total loss
* Ensure that transformer memory length is not larger than max episode steps
* fixed enjoy.py in the case of Searing Spotlights
* added video capture support again after updating memory gym
* default hyperparameters
* remove unnecessary padding from TrXL memory, if applicable
* print SPS
* updated pyproject.toml because of memory-gym 1.0.2
* fixed comment
* refactored code + added a comment
* added annealing entropy coefficient. learning rate anneals also from initial to final value.
* aligned monitoring of losses and further metrics
* slight adjustments due to pre-commit, pre-commit still fails due to unused imports, however these imports are necessary for the used environments to be registered
* heads share parameters again for reproduction purposes, set default hyperparameters for successful MMG training
* fixed entropy for multi-discrete action spaces
* advantage normalization is off per default, added hidden layer after TrXL
* set max_episode_steps for endless environments
* updated poetry.lock, define max episode steps for endless environments, started docs, added benchmark sh
* added Transformer-XL (PPO-TrXL) to the navigation bar, improved docs
* Added report to docs
* refactored "blocks" to "layers"
* Finalized enjoy.py, which can load pre-trained models from huggingface hub, updated docs
* pre-commit enjoy.py
* add #noqa to fix pre-commit
* fix pre-commit #noqa
* pre-commit fixed enjoy import order
* last pass of pre-commit
* fixed spelling
* remove macos-latest from .github/workflows/test.yaml
* Added ppo_trxl to README.md, fixed enjoy and ppo_trxl for MiniGrid, added proper rendering to ProofofMemory-v0, updated docs for training and enjoying MiniGrid and ProofofMemory-v0
* pre-commit fixes
* Add requirements.txt, update poetry, torch defaults to CUDA, updated docs
* updated doc links, added memory gym requirements to README.md

---------

Co-authored-by: Horrible22232 <[email protected]>
1 parent 65789ba, commit 9752b32
Showing 12 changed files with 2,896 additions and 3 deletions.
@@ -0,0 +1,52 @@
# export WANDB_ENTITY=openrlbenchmark

cd cleanrl/ppo_trxl
poetry install
OMP_NUM_THREADS=4 poetry run python -m cleanrl_utils.benchmark \
    --env-ids MortarMayhem-Grid-v0 \
    --command "python ./cleanrl/ppo_trxl/ppo_trxl.py --track --norm_adv --trxl_memory_length 119 --total_timesteps 100000000" \
    --num-seeds 3 \
    --workers 32 \
    --slurm-template-path benchmark/cleanrl_1gpu.slurm_template

OMP_NUM_THREADS=4 poetry run python -m cleanrl_utils.benchmark \
    --env-ids MortarMayhem-v0 \
    --command "python ./cleanrl/ppo_trxl/ppo_trxl.py --track --reconstruction_coef 0.1 --trxl_memory_length 275" \
    --num-seeds 3 \
    --workers 32 \
    --slurm-template-path benchmark/cleanrl_1gpu.slurm_template

OMP_NUM_THREADS=4 poetry run python -m cleanrl_utils.benchmark \
    --env-ids MysteryPath-Grid-v0 \
    --command "python ./cleanrl/ppo_trxl/ppo_trxl.py --track --trxl_memory_length 96 --total_timesteps 100000000" \
    --num-seeds 3 \
    --workers 32 \
    --slurm-template-path benchmark/cleanrl_1gpu.slurm_template

OMP_NUM_THREADS=4 poetry run python -m cleanrl_utils.benchmark \
    --env-ids MysteryPath-v0 \
    --command "python ./cleanrl/ppo_trxl/ppo_trxl.py --track --trxl_memory_length 256" \
    --num-seeds 3 \
    --workers 32 \
    --slurm-template-path benchmark/cleanrl_1gpu.slurm_template

OMP_NUM_THREADS=4 poetry run python -m cleanrl_utils.benchmark \
    --env-ids SearingSpotlights-v0 \
    --command "python ./cleanrl/ppo_trxl/ppo_trxl.py --track --reconstruction_coef 0.1 --trxl_memory_length 256" \
    --num-seeds 3 \
    --workers 32 \
    --slurm-template-path benchmark/cleanrl_1gpu.slurm_template

OMP_NUM_THREADS=4 poetry run python -m cleanrl_utils.benchmark \
    --env-ids Endless-SearingSpotlights-v0 \
    --command "python ./cleanrl/ppo_trxl/ppo_trxl.py --track --reconstruction_coef 0.1 --trxl_memory_length 256 --total_timesteps 350000000" \
    --num-seeds 3 \
    --workers 32 \
    --slurm-template-path benchmark/cleanrl_1gpu.slurm_template

OMP_NUM_THREADS=4 poetry run python -m cleanrl_utils.benchmark \
    --env-ids Endless-MortarMayhem-v0 Endless-MysteryPath-v0 \
    --command "python ./cleanrl/ppo_trxl/ppo_trxl.py --track --trxl_memory_length 256 --total_timesteps 350000000" \
    --num-seeds 3 \
    --workers 32 \
    --slurm-template-path benchmark/cleanrl_1gpu.slurm_template
@@ -0,0 +1,91 @@
from dataclasses import dataclass

import gymnasium as gym
import torch
import tyro
from ppo_trxl import Agent, make_env


@dataclass
class Args:
    hub: bool = False
    """whether to load the model from the huggingface hub or from the local disk"""
    name: str = "Endless-MortarMayhem-v0_12.nn"
    """path to the model file"""


if __name__ == "__main__":
    # Parse command line arguments and retrieve model path
    cli_args = tyro.cli(Args)
    if cli_args.hub:
        try:
            from huggingface_hub import hf_hub_download

            path = hf_hub_download(repo_id="LilHairdy/cleanrl_memory_gym", filename=cli_args.name)
        except Exception as e:
            # Chain the original error so the root cause (missing package, wrong file name, ...) stays visible
            raise RuntimeError(
                "Cannot load model from the huggingface hub. Please install the huggingface_hub pypi package and verify the model name. You can also download the model from the hub manually and load it from disk."
            ) from e
    else:
        path = cli_args.name

    # Load the pre-trained model and the original args used to train it
    checkpoint = torch.load(path)
    args = checkpoint["args"]
    args = type("Args", (), args)

    # Init environment and reset
    env = make_env(args.env_id, 0, False, "", "human")()
    obs, _ = env.reset()
    env.render()

    # Determine maximum episode steps
    max_episode_steps = env.spec.max_episode_steps
    if not max_episode_steps:
        max_episode_steps = env.max_episode_steps
    if max_episode_steps <= 0:
        max_episode_steps = 1024  # Memory Gym envs have max_episode_steps set to -1
        # Max episode steps impacts positional encoding, so make sure to set this accordingly

    # Setup agent and load its model parameters
    action_space_shape = (
        (env.action_space.n,) if isinstance(env.action_space, gym.spaces.Discrete) else tuple(env.action_space.nvec)
    )
    agent = Agent(args, env.observation_space, action_space_shape, max_episode_steps)
    agent.load_state_dict(checkpoint["model_weights"])

    # Setup Transformer-XL memory, mask and indices
    memory = torch.zeros((1, max_episode_steps, args.trxl_num_layers, args.trxl_dim), dtype=torch.float32)
    memory_mask = torch.tril(torch.ones((args.trxl_memory_length, args.trxl_memory_length)), diagonal=-1)
    repetitions = torch.repeat_interleave(
        torch.arange(0, args.trxl_memory_length).unsqueeze(0), args.trxl_memory_length - 1, dim=0
    ).long()
    memory_indices = torch.stack(
        [torch.arange(i, i + args.trxl_memory_length) for i in range(max_episode_steps - args.trxl_memory_length + 1)]
    ).long()
    memory_indices = torch.cat((repetitions, memory_indices))

    # Run episode
    done = False
    t = 0
    while not done:
        # Prepare observation and memory
        obs = torch.Tensor(obs).unsqueeze(0)
        memory_window = memory[0, memory_indices[t].unsqueeze(0)]
        t_ = max(0, min(t, args.trxl_memory_length - 1))
        mask = memory_mask[t_].unsqueeze(0)
        indices = memory_indices[t].unsqueeze(0)
        # Forward agent
        action, _, _, _, new_memory = agent.get_action_and_value(obs, memory_window, mask, indices)
        memory[:, t] = new_memory
        # Step
        obs, reward, termination, truncation, info = env.step(action.cpu().squeeze().numpy())
        env.render()
        done = termination or truncation
        t += 1

    if "r" in info["episode"].keys():
        print(f"Episode return: {info['episode']['r'][0]}, Episode length: {info['episode']['l'][0]}")
    else:
        print(f"Episode return: {info['reward']}, Episode length: {info['length']}")
    env.close()
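
The memory setup in the script above (`memory_mask`, `repetitions`, `memory_indices`) is easier to follow at toy size. The following is a minimal pure-Python sketch of the same indexing scheme, not part of the commit: the helper names `build_memory_indices`, `build_memory_mask`, and `visible_steps` are hypothetical, and plain lists stand in for the torch tensors. It assumes the memory length does not exceed the maximum episode length, which the commit message says the training code ensures.

```python
def build_memory_indices(max_episode_steps, memory_length):
    """One row of memory slots per time step, mirroring `memory_indices`."""
    # The first memory_length - 1 steps all read the fixed window [0, memory_length).
    warmup = [list(range(memory_length)) for _ in range(memory_length - 1)]
    # Afterwards the window slides forward by one slot per step.
    sliding = [
        list(range(i, i + memory_length))
        for i in range(max_episode_steps - memory_length + 1)
    ]
    return warmup + sliding


def build_memory_mask(memory_length):
    """Strictly lower-triangular mask, like torch.tril(ones, diagonal=-1)."""
    return [
        [1 if col < row else 0 for col in range(memory_length)]
        for row in range(memory_length)
    ]


def visible_steps(t, indices, mask):
    """Memory slots that step t may attend to after masking."""
    row = mask[min(t, len(mask) - 1)]  # same clamping as `t_` in the script
    return [slot for slot, keep in zip(indices[t], row) if keep]


if __name__ == "__main__":
    indices = build_memory_indices(max_episode_steps=6, memory_length=3)
    mask = build_memory_mask(memory_length=3)
    for t in range(6):
        print(t, indices[t], visible_steps(t, indices, mask))
```

In this sketch the slot for the current step `t` is always masked out, since its memory entry is only written after the forward pass. Each step therefore sees at most `memory_length - 1` earlier steps.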