Merged in fix/flake8 (pull request #8)
Fixes all the flake8 complaints

Approved-by: Nikita Rudin
Mayankm96 committed Mar 21, 2022
2 parents 9dd09e7 + 2eeff7c commit b9c678e
Showing 16 changed files with 412 additions and 301 deletions.
2 changes: 1 addition & 1 deletion README.md
@@ -21,4 +21,4 @@ black --line-length 120 .
# for checking lints
pip install flake8
flake8 .
```
```
2 changes: 1 addition & 1 deletion rsl_rl/__init__.py
@@ -1,2 +1,2 @@
# Copyright 2021 ETH Zurich, NVIDIA CORPORATION
# SPDX-License-Identifier: BSD-3-Clause
# SPDX-License-Identifier: BSD-3-Clause
6 changes: 5 additions & 1 deletion rsl_rl/algorithms/__init__.py
@@ -1,4 +1,8 @@
# Copyright 2021 ETH Zurich, NVIDIA CORPORATION
# SPDX-License-Identifier: BSD-3-Clause

from .ppo import PPO
"""Implementation of different RL agents."""

from .ppo import PPO

__all__ = ["PPO"]
192 changes: 105 additions & 87 deletions rsl_rl/algorithms/ppo.py
@@ -1,40 +1,36 @@
# Copyright 2021 ETH Zurich, NVIDIA CORPORATION
# SPDX-License-Identifier: BSD-3-Clause

from datetime import datetime
import os
import time

from gym.spaces import Space
import statistics
from collections import deque

# torch
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.tensorboard import SummaryWriter

# rsl-rl
from rsl_rl.modules import ActorCritic
from rsl_rl.storage import RolloutStorage


class PPO:
actor_critic: ActorCritic
def __init__(self,
actor_critic,
num_learning_epochs=1,
num_mini_batches=1,
clip_param=0.2,
gamma=0.998,
lam=0.95,
value_loss_coef=1.0,
entropy_coef=0.0,
learning_rate=1e-3,
max_grad_norm=1.0,
use_clipped_value_loss=True,
schedule="fixed",
desired_kl=0.01,
device='cpu',
):

def __init__(
self,
actor_critic,
num_learning_epochs=1,
num_mini_batches=1,
clip_param=0.2,
gamma=0.998,
lam=0.95,
value_loss_coef=1.0,
entropy_coef=0.0,
learning_rate=1e-3,
max_grad_norm=1.0,
use_clipped_value_loss=True,
schedule="fixed",
desired_kl=0.01,
device="cpu",
):

self.device = device

@@ -45,7 +41,7 @@ def __init__(self,
# PPO components
self.actor_critic = actor_critic
self.actor_critic.to(self.device)
self.storage = None # initialized later
self.storage = None # initialized later
self.optimizer = optim.Adam(self.actor_critic.parameters(), lr=learning_rate)
self.transition = RolloutStorage.Transition()

@@ -61,11 +57,13 @@ def __init__(self,
self.use_clipped_value_loss = use_clipped_value_loss

def init_storage(self, num_envs, num_transitions_per_env, actor_obs_shape, critic_obs_shape, action_shape):
self.storage = RolloutStorage(num_envs, num_transitions_per_env, actor_obs_shape, critic_obs_shape, action_shape, self.device)
self.storage = RolloutStorage(
num_envs, num_transitions_per_env, actor_obs_shape, critic_obs_shape, action_shape, self.device
)

def test_mode(self):
self.actor_critic.test()

def train_mode(self):
self.actor_critic.train()

@@ -82,21 +80,23 @@ def act(self, obs, critic_obs):
self.transition.observations = obs
self.transition.critic_observations = critic_obs
return self.transition.actions

def process_env_step(self, rewards, dones, infos):
self.transition.rewards = rewards.clone()
self.transition.dones = dones
# Bootstrapping on time outs
if 'time_outs' in infos:
self.transition.rewards += self.gamma * torch.squeeze(self.transition.values * infos['time_outs'].unsqueeze(1).to(self.device), 1)
if "time_outs" in infos:
self.transition.rewards += self.gamma * torch.squeeze(
self.transition.values * infos["time_outs"].unsqueeze(1).to(self.device), 1
)

# Record the transition
self.storage.add_transitions(self.transition)
self.transition.clear()
self.actor_critic.reset(dones)

def compute_returns(self, last_critic_obs):
last_values= self.actor_critic.evaluate(last_critic_obs).detach()
last_values = self.actor_critic.evaluate(last_critic_obs).detach()
self.storage.compute_returns(last_values, self.gamma, self.lam)

def update(self):
@@ -106,60 +106,78 @@ def update(self):
generator = self.storage.reccurent_mini_batch_generator(self.num_mini_batches, self.num_learning_epochs)
else:
generator = self.storage.mini_batch_generator(self.num_mini_batches, self.num_learning_epochs)
for obs_batch, critic_obs_batch, actions_batch, target_values_batch, advantages_batch, returns_batch, old_actions_log_prob_batch, \
old_mu_batch, old_sigma_batch, hid_states_batch, masks_batch in generator:


self.actor_critic.act(obs_batch, masks=masks_batch, hidden_states=hid_states_batch[0])
actions_log_prob_batch = self.actor_critic.get_actions_log_prob(actions_batch)
value_batch = self.actor_critic.evaluate(critic_obs_batch, masks=masks_batch, hidden_states=hid_states_batch[1])
mu_batch = self.actor_critic.action_mean
sigma_batch = self.actor_critic.action_std
entropy_batch = self.actor_critic.entropy

# KL
if self.desired_kl != None and self.schedule == 'adaptive':
with torch.inference_mode():
kl = torch.sum(
torch.log(sigma_batch / old_sigma_batch + 1.e-5) + (torch.square(old_sigma_batch) + torch.square(old_mu_batch - mu_batch)) / (2.0 * torch.square(sigma_batch)) - 0.5, axis=-1)
kl_mean = torch.mean(kl)

if kl_mean > self.desired_kl * 2.0:
self.learning_rate = max(1e-5, self.learning_rate / 1.5)
elif kl_mean < self.desired_kl / 2.0 and kl_mean > 0.0:
self.learning_rate = min(1e-2, self.learning_rate * 1.5)

for param_group in self.optimizer.param_groups:
param_group['lr'] = self.learning_rate


# Surrogate loss
ratio = torch.exp(actions_log_prob_batch - torch.squeeze(old_actions_log_prob_batch))
surrogate = -torch.squeeze(advantages_batch) * ratio
surrogate_clipped = -torch.squeeze(advantages_batch) * torch.clamp(ratio, 1.0 - self.clip_param,
1.0 + self.clip_param)
surrogate_loss = torch.max(surrogate, surrogate_clipped).mean()

# Value function loss
if self.use_clipped_value_loss:
value_clipped = target_values_batch + (value_batch - target_values_batch).clamp(-self.clip_param,
self.clip_param)
value_losses = (value_batch - returns_batch).pow(2)
value_losses_clipped = (value_clipped - returns_batch).pow(2)
value_loss = torch.max(value_losses, value_losses_clipped).mean()
else:
value_loss = (returns_batch - value_batch).pow(2).mean()

loss = surrogate_loss + self.value_loss_coef * value_loss - self.entropy_coef * entropy_batch.mean()

# Gradient step
self.optimizer.zero_grad()
loss.backward()
nn.utils.clip_grad_norm_(self.actor_critic.parameters(), self.max_grad_norm)
self.optimizer.step()

mean_value_loss += value_loss.item()
mean_surrogate_loss += surrogate_loss.item()
for (
obs_batch,
critic_obs_batch,
actions_batch,
target_values_batch,
advantages_batch,
returns_batch,
old_actions_log_prob_batch,
old_mu_batch,
old_sigma_batch,
hid_states_batch,
masks_batch,
) in generator:

self.actor_critic.act(obs_batch, masks=masks_batch, hidden_states=hid_states_batch[0])
actions_log_prob_batch = self.actor_critic.get_actions_log_prob(actions_batch)
value_batch = self.actor_critic.evaluate(
critic_obs_batch, masks=masks_batch, hidden_states=hid_states_batch[1]
)
mu_batch = self.actor_critic.action_mean
sigma_batch = self.actor_critic.action_std
entropy_batch = self.actor_critic.entropy

# KL
if self.desired_kl is not None and self.schedule == "adaptive":
with torch.inference_mode():
kl = torch.sum(
torch.log(sigma_batch / old_sigma_batch + 1.0e-5)
+ (torch.square(old_sigma_batch) + torch.square(old_mu_batch - mu_batch))
/ (2.0 * torch.square(sigma_batch))
- 0.5,
axis=-1,
)
kl_mean = torch.mean(kl)

if kl_mean > self.desired_kl * 2.0:
self.learning_rate = max(1e-5, self.learning_rate / 1.5)
elif kl_mean < self.desired_kl / 2.0 and kl_mean > 0.0:
self.learning_rate = min(1e-2, self.learning_rate * 1.5)

for param_group in self.optimizer.param_groups:
param_group["lr"] = self.learning_rate

# Surrogate loss
ratio = torch.exp(actions_log_prob_batch - torch.squeeze(old_actions_log_prob_batch))
surrogate = -torch.squeeze(advantages_batch) * ratio
surrogate_clipped = -torch.squeeze(advantages_batch) * torch.clamp(
ratio, 1.0 - self.clip_param, 1.0 + self.clip_param
)
surrogate_loss = torch.max(surrogate, surrogate_clipped).mean()

# Value function loss
if self.use_clipped_value_loss:
value_clipped = target_values_batch + (value_batch - target_values_batch).clamp(
-self.clip_param, self.clip_param
)
value_losses = (value_batch - returns_batch).pow(2)
value_losses_clipped = (value_clipped - returns_batch).pow(2)
value_loss = torch.max(value_losses, value_losses_clipped).mean()
else:
value_loss = (returns_batch - value_batch).pow(2).mean()

loss = surrogate_loss + self.value_loss_coef * value_loss - self.entropy_coef * entropy_batch.mean()

# Gradient step
self.optimizer.zero_grad()
loss.backward()
nn.utils.clip_grad_norm_(self.actor_critic.parameters(), self.max_grad_norm)
self.optimizer.step()

mean_value_loss += value_loss.item()
mean_surrogate_loss += surrogate_loss.item()

num_updates = self.num_learning_epochs * self.num_mini_batches
mean_value_loss /= num_updates
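
For reference, the largest hunk above reformats the adaptive learning-rate branch of `PPO.update`, which compares the closed-form KL divergence between the old and new diagonal-Gaussian action distributions against `desired_kl`. The sketch below is not part of the commit; it is a minimal check, under assumed tensor shapes and with the `1.0e-5` guard inside the log omitted, that the reformatted expression matches the KL reported by `torch.distributions` for normal distributions.

```python
# Standalone sanity check (not from the commit): the expression used in PPO.update,
# minus the 1.0e-5 numerical guard, equals KL(old || new) for diagonal Gaussians
# summed over action dimensions.
import torch
from torch.distributions import Normal, kl_divergence

torch.manual_seed(0)
batch, num_actions = 8, 12  # arbitrary shapes chosen for this check
old_mu, mu = torch.randn(batch, num_actions), torch.randn(batch, num_actions)
old_sigma = torch.rand(batch, num_actions) + 0.1
sigma = torch.rand(batch, num_actions) + 0.1

# Closed form, following the reformatted hunk above
kl_manual = torch.sum(
    torch.log(sigma / old_sigma)
    + (torch.square(old_sigma) + torch.square(old_mu - mu)) / (2.0 * torch.square(sigma))
    - 0.5,
    dim=-1,
)

# The same quantity computed by torch.distributions
kl_reference = kl_divergence(Normal(old_mu, old_sigma), Normal(mu, sigma)).sum(dim=-1)

assert torch.allclose(kl_manual, kl_reference, atol=1e-6)
```

When the batch mean of this KL rises above `2 * desired_kl`, the hunk above scales the learning rate down by 1.5; when it drops below `desired_kl / 2`, it scales it up, keeping the rate within `[1e-5, 1e-2]`.
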
5 changes: 4 additions & 1 deletion rsl_rl/env/__init__.py
@@ -1,4 +1,7 @@
# Copyright 2021 ETH Zurich, NVIDIA CORPORATION
# SPDX-License-Identifier: BSD-3-Clause
"""Submodule defining the environment definitions."""

from .vec_env import VecEnv
from .vec_env import VecEnv

__all__ = ["VecEnv"]
9 changes: 7 additions & 2 deletions rsl_rl/env/vec_env.py
@@ -1,10 +1,14 @@
# Copyright 2021 ETH Zurich, NVIDIA CORPORATION
# SPDX-License-Identifier: BSD-3-Clause

# python
from abc import ABC, abstractmethod
import torch
from typing import Tuple, Union

# torch
import torch


# minimal interface of the environment
class VecEnv(ABC):
"""Abstract class for vectorized environment."""
@@ -18,7 +22,7 @@ class VecEnv:
obs_buf: torch.Tensor
rew_buf: torch.Tensor
reset_buf: torch.Tensor
episode_length_buf: torch.Tensor # current episode duration
episode_length_buf: torch.Tensor # current episode duration
extras: dict
device: torch.device

@@ -29,6 +33,7 @@ class VecEnv:
@abstractmethod
def get_observations(self) -> torch.Tensor:
pass

@abstractmethod
def get_privileged_observations(self) -> Union[torch.Tensor, None]:
pass
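
Only part of the `VecEnv` interface appears in the hunks above; the rest of `rsl_rl/env/vec_env.py` is collapsed in this diff. Purely as an illustration (not from the commit), here is a hypothetical stub that fills in the members visible here; the constructor arguments, buffer shapes, and any further abstract methods hidden by the collapsed sections (for example `step` and `reset` in the full file) are assumptions, so a real subclass will need more than this.

```python
# Hypothetical stub (not from the commit) covering only the VecEnv members visible
# in the hunks above; anything hidden by the collapsed diff sections is assumed
# and may be missing here.
from typing import Union

import torch

from rsl_rl.env import VecEnv


class SingleObsDummyEnv(VecEnv):
    def __init__(self, num_envs: int = 1, num_obs: int = 4, device: str = "cpu"):
        self.device = torch.device(device)
        self.obs_buf = torch.zeros(num_envs, num_obs, device=self.device)
        self.rew_buf = torch.zeros(num_envs, device=self.device)
        self.reset_buf = torch.ones(num_envs, device=self.device)
        self.episode_length_buf = torch.zeros(
            num_envs, dtype=torch.long, device=self.device
        )  # current episode duration
        self.extras = {}

    def get_observations(self) -> torch.Tensor:
        return self.obs_buf

    def get_privileged_observations(self) -> Union[torch.Tensor, None]:
        return None  # this stub exposes no privileged observations
```
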
6 changes: 5 additions & 1 deletion rsl_rl/modules/__init__.py
@@ -1,5 +1,9 @@
# Copyright 2021 ETH Zurich, NVIDIA CORPORATION
# SPDX-License-Identifier: BSD-3-Clause

"""Definitions for neural-network components for RL-agents."""

from .actor_critic import ActorCritic
from .actor_critic_recurrent import ActorCriticRecurrent
from .actor_critic_recurrent import ActorCriticRecurrent

__all__ = ["ActorCritic", "ActorCriticRecurrent"]