add tests for loaded policies, revise error handling, add docstrings, edit type annotations, remove credit message
salaast committed Jul 19, 2023
1 parent c0c3e4d commit 3bc4fe6
Showing 4 changed files with 57 additions and 86 deletions.
71 changes: 24 additions & 47 deletions compiler_opt/es/policy_utils.py
@@ -12,52 +12,21 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

###############################################################################
#
#
# This is a port of the code by Krzysztof Choromanski, Deepali Jain and Vikas
# Sindhwani, based on the portfolio of Blackbox optimization algorithms listed
# below:
#
# "On Blackbox Backpropagation and Jacobian Sensing"; K. Choromanski,
# V. Sindhwani, NeurIPS 2017
# "Optimizing Simulations with Noise-Tolerant Structured Exploration"; K.
# Choromanski, A. Iscen, V. Sindhwani, J. Tan, E. Coumans, ICRA 2018
# "Structured Evolution with Compact Architectures for Scalable Policy
# Optimization"; K. Choromanski, M. Rowland, V. Sindhwani, R. Turner, A.
# Weller, ICML 2018, https://arxiv.org/abs/1804.02395
# "From Complexity to Simplicity: Adaptive ES-Active Subspaces for Blackbox
# Optimization"; K. Choromanski, A. Pacchiano, J. Parker-Holder, Y. Tang, V.
# Sindhwani, NeurIPS 2019
# "i-Sim2Real: Reinforcement Learning on Robotic Policies in Tight Human-Robot
# Interaction Loops"; L. Graesser, D. D'Ambrosio, A. Singh, A. Bewley, D. Jain,
# K. Choromanski, P. Sanketi , CoRL 2022, https://arxiv.org/abs/2207.06572
# "Agile Catching with Whole-Body MPC and Blackbox Policy Learning"; S.
# Abeyruwan, A. Bewley, N. Boffi, K. Choromanski, D. D'Ambrosio, D. Jain, P.
# Sanketi, A. Shankar, V. Sindhwani, S. Singh, J. Slotine, S. Tu, L4DC,
# https://arxiv.org/abs/2306.08205
# "Robotic Table Tennis: A Case Study into a High Speed Learning System"; A.
# Bewley, A. Shankar, A. Iscen, A. Singh, C. Lynch, D. D'Ambrosio, D. Jain,
# E. Coumans, G. Versom, G. Kouretas, J. Abelian, J. Boyd, K. Oslund,
# K. Reymann, K. Choromanski, L. Graesser, M. Ahn, N. Jaitly, N. Lazic,
# P. Sanketi, P. Xu, P. Sermanet, R. Mahjourian, S. Abeyruwan, S. Kataoka,
# S. Moore, T. Nguyen, T. Ding, V. Sindhwani, V. Vanhoucke, W. Gao, Y. Kuang,
# to be presented at RSS 2023
###############################################################################
"""Util functions to create and edit a tf_agent policy."""

import gin
import numpy as np
import numpy.typing as npt
import tensorflow as tf
from tensorflow.python.trackable import autotrackable
from typing import Union

from tf_agents.networks import network
from tf_agents.policies import actor_policy, greedy_policy, tf_policy
from compiler_opt.rl import policy_saver, registry


# TODO(abenalaast): Issue #280
@gin.configurable(module='policy_utils')
def create_actor_policy(actor_network_ctor: network.DistributionNetwork,
greedy: bool = False) -> tf_policy.TFPolicy:
@@ -85,44 +54,52 @@ def create_actor_policy(actor_network_ctor: network.DistributionNetwork,


def get_vectorized_parameters_from_policy(
policy: Union[tf_policy.TFPolicy, tf.Module]) -> npt.NDArray[np.float32]:
policy: Union[tf_policy.TFPolicy, autotrackable.AutoTrackable]
) -> npt.NDArray[np.float32]:
"""Returns a policy's variable values as a single np array."""
if isinstance(policy, tf_policy.TFPolicy):
variables = policy.variables()
elif policy.model_variables:
elif hasattr(policy, 'model_variables'):
variables = policy.model_variables
else:
raise ValueError('policy must be a TFPolicy or a loaded SavedModel')

parameters = [var.numpy().flatten() for var in variables]
parameters = np.concatenate(parameters, axis=0)
return parameters


def set_vectorized_parameters_for_policy(
policy: Union[tf_policy.TFPolicy,
tf.Module], parameters: npt.NDArray[np.float32]) -> None:
policy: Union[tf_policy.TFPolicy, autotrackable.AutoTrackable],
parameters: npt.NDArray[np.float32]) -> None:
"""Separates values in parameters into the policy's shapes
and sets the policy variables to those values"""
if isinstance(policy, tf_policy.TFPolicy):
variables = policy.variables()
else:
try:
getattr(policy, 'model_variables')
except AttributeError as e:
raise TypeError('policy must be a TFPolicy or a loaded SavedModel') from e
elif hasattr(policy, 'model_variables'):
variables = policy.model_variables
else:
raise ValueError('policy must be a TFPolicy or a loaded SavedModel')

param_pos = 0
for variable in variables:
shape = tf.shape(variable).numpy()
num_ele = np.prod(shape)
param = np.reshape(parameters[param_pos:param_pos + num_ele], shape)
num_elems = np.prod(shape)
param = np.reshape(parameters[param_pos:param_pos + num_elems], shape)
variable.assign(param)
param_pos += num_ele
param_pos += num_elems
if param_pos != len(parameters):
raise ValueError(
        f'Parameter dimensions do not match! Expected {len(parameters)} '
        f'but only found {param_pos}.')


def save_policy(policy: tf_policy.TFPolicy, parameters: npt.NDArray[np.float32],
save_folder: str, policy_name: str) -> None:
def save_policy(policy: Union[tf_policy.TFPolicy, autotrackable.AutoTrackable],
parameters: npt.NDArray[np.float32], save_folder: str,
policy_name: str) -> None:
"""Assigns a policy the name policy_name
and saves it to the directory of save_folder
with the values in parameters."""
set_vectorized_parameters_for_policy(policy, parameters)
saver = policy_saver.PolicySaver({policy_name: policy})
saver.save(save_folder)
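
For reference, a minimal sketch of how the revised helpers fit together; the save directory and policy name below are hypothetical, and create_actor_policy is assumed to have its actor_network_ctor supplied via gin as in the module above:

import numpy as np
import tensorflow as tf

from compiler_opt.es import policy_utils

# Build a TF-Agents policy, flatten its variables, and save it under a name.
policy = policy_utils.create_actor_policy()
params = policy_utils.get_vectorized_parameters_from_policy(policy)
policy_utils.save_policy(policy, params, '/tmp/es_policy', 'policy')  # hypothetical path/name

# Reload it as a SavedModel (an AutoTrackable, not a TFPolicy) and round-trip
# the same parameter vector through the loaded object.
loaded = tf.saved_model.load('/tmp/es_policy/policy')
policy_utils.set_vectorized_parameters_for_policy(loaded, params)
np.testing.assert_array_almost_equal(
    policy_utils.get_vectorized_parameters_from_policy(loaded), params)
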
64 changes: 29 additions & 35 deletions compiler_opt/es/policy_utils_test.py
@@ -12,56 +12,25 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

###############################################################################
#
#
# This is a port of the code by Krzysztof Choromanski, Deepali Jain and Vikas
# Sindhwani, based on the portfolio of Blackbox optimization algorithms listed
# below:
#
# "On Blackbox Backpropagation and Jacobian Sensing"; K. Choromanski,
# V. Sindhwani, NeurIPS 2017
# "Optimizing Simulations with Noise-Tolerant Structured Exploration"; K.
# Choromanski, A. Iscen, V. Sindhwani, J. Tan, E. Coumans, ICRA 2018
# "Structured Evolution with Compact Architectures for Scalable Policy
# Optimization"; K. Choromanski, M. Rowland, V. Sindhwani, R. Turner, A.
# Weller, ICML 2018, https://arxiv.org/abs/1804.02395
# "From Complexity to Simplicity: Adaptive ES-Active Subspaces for Blackbox
# Optimization"; K. Choromanski, A. Pacchiano, J. Parker-Holder, Y. Tang, V.
# Sindhwani, NeurIPS 2019
# "i-Sim2Real: Reinforcement Learning on Robotic Policies in Tight Human-Robot
# Interaction Loops"; L. Graesser, D. D'Ambrosio, A. Singh, A. Bewley, D. Jain,
# K. Choromanski, P. Sanketi , CoRL 2022, https://arxiv.org/abs/2207.06572
# "Agile Catching with Whole-Body MPC and Blackbox Policy Learning"; S.
# Abeyruwan, A. Bewley, N. Boffi, K. Choromanski, D. D'Ambrosio, D. Jain, P.
# Sanketi, A. Shankar, V. Sindhwani, S. Singh, J. Slotine, S. Tu, L4DC,
# https://arxiv.org/abs/2306.08205
# "Robotic Table Tennis: A Case Study into a High Speed Learning System"; A.
# Bewley, A. Shankar, A. Iscen, A. Singh, C. Lynch, D. D'Ambrosio, D. Jain,
# E. Coumans, G. Versom, G. Kouretas, J. Abelian, J. Boyd, K. Oslund,
# K. Reymann, K. Choromanski, L. Graesser, M. Ahn, N. Jaitly, N. Lazic,
# P. Sanketi, P. Xu, P. Sermanet, R. Mahjourian, S. Abeyruwan, S. Kataoka,
# S. Moore, T. Nguyen, T. Ding, V. Sindhwani, V. Vanhoucke, W. Gao, Y. Kuang,
# to be presented at RSS 2023
###############################################################################
"""Tests for policy_utils."""

from absl.testing import absltest
import numpy as np
import os
import tensorflow as tf
from tensorflow.python.trackable import autotrackable
from tf_agents.networks import actor_distribution_network
from tf_agents.policies import actor_policy
from tf_agents.policies import actor_policy, tf_policy

from compiler_opt.es import policy_utils
from compiler_opt.rl import policy_saver, registry
from compiler_opt.rl.inlining import InliningConfig
from compiler_opt.rl.inlining import config as inlining_config
from compiler_opt.rl.inlining import InliningConfig
from compiler_opt.rl.regalloc import config as regalloc_config
from compiler_opt.rl.regalloc import RegallocEvictionConfig, regalloc_network


# TODO(abenalaast): Issue #280
class ConfigTest(absltest.TestCase):

def test_inlining_config(self):
@@ -167,6 +136,21 @@ def test_set_vectorized_parameters_for_policy(self):
raise AssertionError(f'values at index {idx} do not match')
idx += 1

# get saved model to test a loaded policy
sm = tf.saved_model.load(policy_save_path + '/policy')
self.assertIsInstance(sm, autotrackable.AutoTrackable)
self.assertNotIsInstance(sm, tf_policy.TFPolicy)
params = params[::-1]
policy_utils.set_vectorized_parameters_for_policy(sm, params)
val = length_of_a_perturbation - 1
for variable in sm.model_variables:
nums = variable.numpy().flatten()
for num in nums:
if val != num:
raise AssertionError(
f'values at index {length_of_a_perturbation - val} do not match')
val -= 1

def test_get_vectorized_parameters_from_policy(self):
# create a policy
problem_config = registry.get_configuration(implementation=InliningConfig)
@@ -205,6 +189,16 @@ def test_get_vectorized_parameters_from_policy(self):
output = policy_utils.get_vectorized_parameters_from_policy(policy)
np.testing.assert_array_almost_equal(output, params)

# get saved model to test a loaded policy
sm = tf.saved_model.load(policy_save_path + '/policy')
self.assertIsInstance(sm, autotrackable.AutoTrackable)
self.assertNotIsInstance(sm, tf_policy.TFPolicy)
params = params[::-1]
policy_utils.set_vectorized_parameters_for_policy(sm, params)
# vectorize and check if the outcome is the same as the start
output = policy_utils.get_vectorized_parameters_from_policy(sm)
np.testing.assert_array_almost_equal(output, params)


if __name__ == '__main__':
absltest.main()
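
As the assertions in the new tests indicate, an object restored with tf.saved_model.load is an AutoTrackable rather than a TFPolicy, which is why the helpers now branch on the model_variables attribute instead of raising from a getattr probe. A minimal sketch of that contract, with a hypothetical load path:

import tensorflow as tf
from tf_agents.policies import tf_policy

loaded = tf.saved_model.load('/tmp/es_policy/policy')  # hypothetical path
assert not isinstance(loaded, tf_policy.TFPolicy)  # restored object is not a TFPolicy
assert hasattr(loaded, 'model_variables')          # the attribute the helpers dispatch on
variables = loaded.model_variables                 # list of tf.Variable to read or assign
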
4 changes: 2 additions & 2 deletions compiler_opt/rl/policy_saver.py
@@ -169,8 +169,8 @@ def __init__(self, policy_dict: Dict[str, tf_policy.TFPolicy]):
self._policy_saver_dict: Dict[str, Tuple[
policy_saver.PolicySaver, tf_policy.TFPolicy]] = {
policy_name: (policy_saver.PolicySaver(
policy, batch_size=1, use_nest_path_signatures=False), policy)
for policy_name, policy in policy_dict.items()
policy, batch_size=1, use_nest_path_signatures=False), policy
) for policy_name, policy in policy_dict.items()
}

def _save_policy(self, saver, path):
4 changes: 2 additions & 2 deletions compiler_opt/rl/train_locally.py
@@ -159,8 +159,8 @@ def sequence_example_iterator_fn(seq_ex: List[str]):

# Repeat for num_policy_iterations iterations.
t1 = time.time()
while (llvm_trainer.global_step_numpy() <
num_policy_iterations * num_iterations):
while (llvm_trainer.global_step_numpy()
< num_policy_iterations * num_iterations):
t2 = time.time()
logging.info('Last iteration took: %f', t2 - t1)
t1 = t2
