Merged in feature/empirical_norm (pull request #7)
Feature/empirical norm

Approved-by: Nikita Rudin
Approved-by: Marko Bjelonic
fabianje authored and Nikita Rudin committed May 3, 2022
2 parents b9c678e + 9498e29 commit 7950a59
Showing 2 changed files with 150 additions and 1 deletion.
9 changes: 8 additions & 1 deletion rsl_rl/modules/actor_critic.py
@@ -5,6 +5,7 @@
import torch
import torch.nn as nn
from torch.distributions import Normal
from rsl_rl.modules.normalizer import EmpiricalNormalization


class ActorCritic(nn.Module):
@@ -19,6 +20,7 @@ def __init__(
        critic_hidden_dims=[256, 256, 256],
        activation="elu",
        init_noise_std=1.0,
        update_obs_norm=True,
        **kwargs,
    ):
        if kwargs:
@@ -27,14 +29,16 @@ def __init__(
                + str([key for key in kwargs.keys()])
            )
        super(ActorCritic, self).__init__()

        activation = get_activation(activation)

        mlp_input_dim_a = num_actor_obs
        mlp_input_dim_c = num_critic_obs

        # Policy
        actor_layers = []
        actor_layers.append(
            EmpiricalNormalization(shape=[mlp_input_dim_a], update_obs_norm=update_obs_norm, until=1.0e8)
        )
        actor_layers.append(nn.Linear(mlp_input_dim_a, actor_hidden_dims[0]))
        actor_layers.append(activation)
        for layer_index in range(len(actor_hidden_dims)):
@@ -47,6 +51,9 @@ def __init__(

        # Value function
        critic_layers = []
        critic_layers.append(
            EmpiricalNormalization(shape=[mlp_input_dim_c], update_obs_norm=update_obs_norm, until=1.0e8)
        )
        critic_layers.append(nn.Linear(mlp_input_dim_c, critic_hidden_dims[0]))
        critic_layers.append(activation)
        for layer_index in range(len(critic_hidden_dims)):
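For context, a minimal usage sketch of the change above (not part of the commit). The full constructor signature and the `self.actor` attribute are not visible in this truncated diff, so `num_actions`, the observation dimensions, and the attribute access below are assumptions based on the surrounding rsl_rl code:

import torch
from rsl_rl.modules.actor_critic import ActorCritic

# Hypothetical dimensions; real values come from the environment configuration.
policy = ActorCritic(num_actor_obs=48, num_critic_obs=48, num_actions=12, update_obs_norm=True)

# The first submodule of each MLP is now the EmpiricalNormalization layer,
# so raw observations are whitened before they reach the first nn.Linear.
obs = torch.randn(64, 48)
action_mean = policy.actor(obs)  # `actor` attribute assumed from rsl_rl; shape (64, 12)
print(type(policy.actor[0]).__name__)  # -> EmpiricalNormalization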
142 changes: 142 additions & 0 deletions rsl_rl/modules/normalizer.py
@@ -0,0 +1,142 @@
# MIT License
#
# Copyright (c) 2020 Preferred Networks, Inc.
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.
import numpy as np

import torch
from torch import nn


class EmpiricalNormalization(nn.Module):
"""Normalize mean and variance of values based on empirical values.
Args:
shape (int or tuple of int): Shape of input values except batch axis.
batch_axis (int): Batch axis.
eps (float): Small value for stability.
dtype (dtype): Dtype of input values.
until (int or None): If this arg is specified, the link learns input values until the sum of batch sizes
exceeds it.
update_obs_norm (bool): If true, learns updates mean and variance
"""

    def __init__(
        self,
        shape,
        batch_axis=0,
        eps=1e-2,
        dtype=np.float32,
        until=None,
        clip_threshold=None,
        update_obs_norm=True,
    ):
        super(EmpiricalNormalization, self).__init__()
        dtype = np.dtype(dtype)
        self.batch_axis = batch_axis
        self.eps = eps
        self.until = until
        self.clip_threshold = clip_threshold
        self.register_buffer(
            "_mean",
            torch.tensor(np.expand_dims(np.zeros(shape, dtype=dtype), batch_axis)),
        )
        self.register_buffer(
            "_var",
            torch.tensor(np.expand_dims(np.ones(shape, dtype=dtype), batch_axis)),
        )
        self.register_buffer("count", torch.tensor(0))
        self.in_features = shape[0]

        # cache
        self._cached_std_inverse = torch.tensor(np.expand_dims(np.ones(shape, dtype=dtype), batch_axis))
        self._is_std_cached = False
        self._is_training = update_obs_norm

    @property
    def mean(self):
        return torch.squeeze(self._mean, self.batch_axis).clone()

    @property
    def std(self):
        return torch.sqrt(torch.squeeze(self._var, self.batch_axis)).clone()

    @property
    def _std_inverse(self):
        if self._is_std_cached is False:
            self._cached_std_inverse = (self._var + self.eps) ** -0.5
            self._is_std_cached = True  # mark the cache valid; cleared again whenever the statistics change
        return self._cached_std_inverse

    @torch.jit.unused
    @torch.no_grad()
    def experience(self, x):
        """Update the running statistics from a batch of input values without computing outputs."""

        if self.until is not None:
            if self.count >= self.until:
                return

        count_x = x.shape[self.batch_axis]
        if count_x == 0:
            return

        self.count += count_x
        rate = count_x / self.count.float()
        assert rate > 0
        assert rate <= 1

        var_x = torch.var(x, dim=self.batch_axis, unbiased=False, keepdim=True)
        mean_x = torch.mean(x, dim=self.batch_axis, keepdim=True)
        delta_mean = mean_x - self._mean
        self._mean += rate * delta_mean
        self._var += rate * (var_x - self._var + delta_mean * (mean_x - self._mean))

        # clear the cached inverse std; it will be recomputed on next use
        self._is_std_cached = False

    def forward(self, x):
        """Normalize input values using the running empirical mean and variance.
        Args:
            x (torch.Tensor): Input values
        Returns:
            torch.Tensor: Normalized output values
        """

        if self._is_training:
            self.experience(x)

        # On CPU, invalidate the cache so the inverse std is recomputed each call.
        if not x.is_cuda:
            self._is_std_cached = False
        normalized = (x - self._mean) * self._std_inverse
        if self.clip_threshold is not None:
            normalized = torch.clamp(normalized, -self.clip_threshold, self.clip_threshold)
        if not x.is_cuda:
            self._is_std_cached = False
        return normalized

    @torch.jit.unused
    def inverse(self, y):
        std = torch.sqrt(self._var + self.eps)
        return y * std + self._mean

    def load_numpy(self, mean, var, count, device="cpu"):
        self._mean = torch.from_numpy(np.expand_dims(mean, self.batch_axis)).to(device)
        self._var = torch.from_numpy(np.expand_dims(var, self.batch_axis)).to(device)
        self.count = torch.tensor(count).to(device)
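As a sanity check (not part of the commit), a short sketch of the module in isolation: the update in experience() is an online form of the standard pooled mean/variance update, so after streaming several batches the buffers should match the plain statistics of all samples seen so far. Values are illustrative only:

import torch
from rsl_rl.modules.normalizer import EmpiricalNormalization

norm = EmpiricalNormalization(shape=[3])

# Stream two batches; forward() first folds each batch into the running
# statistics via experience(), then whitens the batch.
a = torch.randn(128, 3) * 2.0 + 5.0
b = torch.randn(64, 3) * 2.0 + 5.0
_ = norm(a)
_ = norm(b)

# The incremental update reproduces the plain statistics of all 192 samples
# seen so far (variance is the biased, unbiased=False estimate).
data = torch.cat([a, b])
print(torch.allclose(norm.mean, data.mean(dim=0), atol=1e-5))
print(torch.allclose(norm.std, data.var(dim=0, unbiased=False).sqrt(), atol=1e-4))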

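And a brief hypothetical sketch of the two helpers at the end of the file: inverse() undoes the normalization exactly, since both directions use the same eps-stabilized std, and load_numpy() restores statistics exported as numpy arrays (the values below are stand-ins, not real checkpoint data):

import numpy as np
import torch
from rsl_rl.modules.normalizer import EmpiricalNormalization

norm = EmpiricalNormalization(shape=[3])
x = torch.randn(32, 3)
y = norm(x)
print(torch.allclose(norm.inverse(y), x, atol=1e-5))  # round trip recovers x

# Hypothetical statistics, e.g. exported from an earlier training run.
mean = np.array([0.1, -0.2, 0.3], dtype=np.float32)
var = np.array([1.5, 0.5, 2.0], dtype=np.float32)
norm.load_numpy(mean, var, count=10000)
print(norm.mean)  # tensor([ 0.1000, -0.2000,  0.3000])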