Updated Proximal Policy Optimization example for Keras v3 (keras-team#1783)

* Updated Proximal Policy Optimization example for Keras v3

* updated script and added generated files
lpizzinidev authored Mar 13, 2024
1 parent 21fe555 commit 901487f
Showing 4 changed files with 191 additions and 131 deletions.
78 changes: 43 additions & 35 deletions examples/rl/ipynb/ppo_cartpole.ipynb
@@ -10,8 +10,8 @@
"\n",
"**Author:** [Ilias Chrysovergis](https://twitter.com/iliachry)<br>\n",
"**Date created:** 2021/06/24<br>\n",
"**Last modified:** 2021/06/24<br>\n",
"**Description:** Implementation of a Proximal Policy Optimization agent for the CartPole-v0 environment."
"**Last modified:** 2024/03/12<br>\n",
"**Description:** Implementation of a Proximal Policy Optimization agent for the CartPole-v1 environment."
]
},
{
@@ -22,9 +22,9 @@
"source": [
"## Introduction\n",
"\n",
"This code example solves the CartPole-v0 environment using a Proximal Policy Optimization (PPO) agent.\n",
"This code example solves the CartPole-v1 environment using a Proximal Policy Optimization (PPO) agent.\n",
"\n",
"### CartPole-v0\n",
"### CartPole-v1\n",
"\n",
"A pole is attached by an un-actuated joint to a cart, which moves along a frictionless track.\n",
"The system is controlled by applying a force of +1 or -1 to the cart.\n",
@@ -33,7 +33,7 @@
"The episode ends when the pole is more than 15 degrees from vertical, or the cart moves more than 2.4 units from the center.\n",
"After 200 steps the episode ends. Thus, the highest return we can get is equal to 200.\n",
"\n",
"[CartPole-v0](https://gym.openai.com/envs/CartPole-v0/)\n",
"[CartPole-v1](https://gymnasium.farama.org/environments/classic_control/cart_pole/)\n",
"\n",
"### Proximal Policy Optimization\n",
"\n",
@@ -47,7 +47,7 @@
"\n",
"![Algorithm](https://i.imgur.com/rd5tda1.png)\n",
"\n",
"- [PPO Original Paper](https://arxiv.org/pdf/1707.06347.pdf)\n",
"- [Proximal Policy Optimization Algorithms](https://arxiv.org/abs/1707.06347)\n",
"- [OpenAI Spinning Up docs - PPO](https://spinningup.openai.com/en/latest/algorithms/ppo.html)\n",
"\n",
"### Note\n",
@@ -70,7 +70,7 @@
"\n",
"1. `numpy` for n-dimensional arrays\n",
"2. `tensorflow` and `keras` for building the deep RL PPO agent\n",
"3. `gym` for getting everything we need about the environment\n",
"3. `gymnasium` for getting everything we need about the environment\n",
"4. `scipy.signal` for calculating the discounted cumulative sums of vectors"
]
},
@@ -82,13 +82,17 @@
},
"outputs": [],
"source": [
"import os\n",
"\n",
"os.environ[\"KERAS_BACKEND\"] = \"tensorflow\"\n",
"\n",
"import keras\n",
"from keras import layers\n",
"\n",
"import numpy as np\n",
"import tensorflow as tf\n",
"from tensorflow import keras\n",
"from tensorflow.keras import layers\n",
"import gym\n",
"import scipy.signal\n",
"import time"
"import gymnasium as gym\n",
"import scipy.signal"
]
},
{
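A note on the rewritten import block: in Keras 3 the backend is chosen through the `KERAS_BACKEND` environment variable, and it must be set before the first `import keras`. A minimal sketch of the pattern, assuming the TensorFlow backend (which this example still relies on for `tf.GradientTape` and `@tf.function`):

```python
# Sketch only: select the backend before keras is imported for the first time.
import os

os.environ["KERAS_BACKEND"] = "tensorflow"  # has no effect if keras was already imported

import keras

print(keras.backend.backend())  # -> "tensorflow"
```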
@@ -173,7 +177,7 @@
" )\n",
"\n",
"\n",
"def mlp(x, sizes, activation=tf.tanh, output_activation=None):\n",
"def mlp(x, sizes, activation=keras.activations.tanh, output_activation=None):\n",
" # Build a feedforward neural network\n",
" for size in sizes[:-1]:\n",
" x = layers.Dense(units=size, activation=activation)(x)\n",
@@ -182,18 +186,23 @@
"\n",
"def logprobabilities(logits, a):\n",
" # Compute the log-probabilities of taking actions a by using the logits (i.e. the output of the actor)\n",
- " logprobabilities_all = tf.nn.log_softmax(logits)\n",
- " logprobability = tf.reduce_sum(\n",
- " tf.one_hot(a, num_actions) * logprobabilities_all, axis=1\n",
+ " logprobabilities_all = keras.ops.log_softmax(logits)\n",
+ " logprobability = keras.ops.sum(\n",
+ " keras.ops.one_hot(a, num_actions) * logprobabilities_all, axis=1\n",
" )\n",
" return logprobability\n",
"\n",
"\n",
"seed_generator = keras.random.SeedGenerator(1337)\n",
"\n",
"\n",
"# Sample action from actor\n",
"@tf.function\n",
"def sample_action(observation):\n",
" logits = actor(observation)\n",
- " action = tf.squeeze(tf.random.categorical(logits, 1), axis=1)\n",
+ " action = keras.ops.squeeze(\n",
+ " keras.random.categorical(logits, 1, seed=seed_generator), axis=1\n",
+ " )\n",
" return logits, action\n",
"\n",
"\n",
@@ -202,37 +211,36 @@
"def train_policy(\n",
" observation_buffer, action_buffer, logprobability_buffer, advantage_buffer\n",
"):\n",
"\n",
" with tf.GradientTape() as tape: # Record operations for automatic differentiation.\n",
- " ratio = tf.exp(\n",
+ " ratio = keras.ops.exp(\n",
" logprobabilities(actor(observation_buffer), action_buffer)\n",
" - logprobability_buffer\n",
" )\n",
- " min_advantage = tf.where(\n",
+ " min_advantage = keras.ops.where(\n",
" advantage_buffer > 0,\n",
" (1 + clip_ratio) * advantage_buffer,\n",
" (1 - clip_ratio) * advantage_buffer,\n",
" )\n",
"\n",
- " policy_loss = -tf.reduce_mean(\n",
- " tf.minimum(ratio * advantage_buffer, min_advantage)\n",
+ " policy_loss = -keras.ops.mean(\n",
+ " keras.ops.minimum(ratio * advantage_buffer, min_advantage)\n",
" )\n",
" policy_grads = tape.gradient(policy_loss, actor.trainable_variables)\n",
" policy_optimizer.apply_gradients(zip(policy_grads, actor.trainable_variables))\n",
"\n",
- " kl = tf.reduce_mean(\n",
+ " kl = keras.ops.mean(\n",
" logprobability_buffer\n",
" - logprobabilities(actor(observation_buffer), action_buffer)\n",
" )\n",
- " kl = tf.reduce_sum(kl)\n",
+ " kl = keras.ops.sum(kl)\n",
" return kl\n",
"\n",
"\n",
"# Train the value function by regression on mean-squared error\n",
"@tf.function\n",
"def train_value_function(observation_buffer, return_buffer):\n",
" with tf.GradientTape() as tape: # Record operations for automatic differentiation.\n",
- " value_loss = tf.reduce_mean((return_buffer - critic(observation_buffer)) ** 2)\n",
+ " value_loss = keras.ops.mean((return_buffer - critic(observation_buffer)) ** 2)\n",
" value_grads = tape.gradient(value_loss, critic.trainable_variables)\n",
" value_optimizer.apply_gradients(zip(value_grads, critic.trainable_variables))\n",
""
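One detail behind these changes: when the TensorFlow backend is active, `keras.ops` calls dispatch to native TensorFlow ops, which is why they can still be traced by `@tf.function` and differentiated with `tf.GradientTape` as before. A tiny sketch of that interplay (assuming the backend is set to TensorFlow as in the imports above):

```python
import os

os.environ["KERAS_BACKEND"] = "tensorflow"

import keras
import tensorflow as tf

x = tf.Variable([0.5, 1.0, 2.0])
with tf.GradientTape() as tape:
    # keras.ops.* map to tf ops on this backend, so the tape can record them.
    loss = keras.ops.mean(keras.ops.exp(x))
grads = tape.gradient(loss, x)
print(grads.numpy())  # equals exp(x) / 3, i.e. d/dx of mean(exp(x))
```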
@@ -291,28 +299,27 @@
"source": [
"# Initialize the environment and get the dimensionality of the\n",
"# observation space and the number of possible actions\n",
"env = gym.make(\"CartPole-v0\")\n",
"env = gym.make(\"CartPole-v1\")\n",
"observation_dimensions = env.observation_space.shape[0]\n",
"num_actions = env.action_space.n\n",
"\n",
"# Initialize the buffer\n",
"buffer = Buffer(observation_dimensions, steps_per_epoch)\n",
"\n",
"# Initialize the actor and the critic as keras models\n",
"observation_input = keras.Input(shape=(observation_dimensions,), dtype=tf.float32)\n",
"logits = mlp(observation_input, list(hidden_sizes) + [num_actions], tf.tanh, None)\n",
"observation_input = keras.Input(shape=(observation_dimensions,), dtype=\"float32\")\n",
"logits = mlp(observation_input, list(hidden_sizes) + [num_actions])\n",
"actor = keras.Model(inputs=observation_input, outputs=logits)\n",
"value = tf.squeeze(\n",
" mlp(observation_input, list(hidden_sizes) + [1], tf.tanh, None), axis=1\n",
")\n",
"value = keras.ops.squeeze(mlp(observation_input, list(hidden_sizes) + [1]), axis=1)\n",
"critic = keras.Model(inputs=observation_input, outputs=value)\n",
"\n",
"# Initialize the policy and the value function optimizers\n",
"policy_optimizer = keras.optimizers.Adam(learning_rate=policy_learning_rate)\n",
"value_optimizer = keras.optimizers.Adam(learning_rate=value_function_learning_rate)\n",
"\n",
"# Initialize the observation, episode return and episode length\n",
"observation, episode_return, episode_length = env.reset(), 0, 0"
"observation, _ = env.reset()\n",
"episode_return, episode_length = 0, 0"
]
},
{
@@ -347,7 +354,7 @@
" # Get the logits, action, and take one step in the environment\n",
" observation = observation.reshape(1, -1)\n",
" logits, action = sample_action(observation)\n",
- " observation_new, reward, done, _ = env.step(action[0].numpy())\n",
+ " observation_new, reward, done, _, _ = env.step(action[0].numpy())\n",
" episode_return += reward\n",
" episode_length += 1\n",
"\n",
@@ -369,7 +376,8 @@
" sum_return += episode_return\n",
" sum_length += episode_length\n",
" num_episodes += 1\n",
- " observation, episode_return, episode_length = env.reset(), 0, 0\n",
+ " observation, _ = env.reset()\n",
+ " episode_return, episode_length = 0, 0\n",
"\n",
" # Get values from the buffer\n",
" (\n",
@@ -422,7 +430,7 @@
}
],
"metadata": {
"accelerator": "GPU",
"accelerator": "None",
"colab": {
"collapsed_sections": [],
"name": "ppo_cartpole",