-
Notifications
You must be signed in to change notification settings - Fork 93
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
add policy_utils #279
add policy_utils #279
Changes from 1 commit
c0c3e4d
35425e8
abe2201
96dcfe2
f9d098e
9763737
485b8b1
74b7605
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,128 @@ | ||
# coding=utf-8 | ||
# Copyright 2020 Google LLC | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
|
||
############################################################################### | ||
# | ||
# | ||
# This is a port of the code by Krzysztof Choromanski, Deepali Jain and Vikas | ||
# Sindhwani, based on the portfolio of Blackbox optimization algorithms listed | ||
# below: | ||
# | ||
# "On Blackbox Backpropagation and Jacobian Sensing"; K. Choromanski, | ||
# V. Sindhwani, NeurIPS 2017 | ||
# "Optimizing Simulations with Noise-Tolerant Structured Exploration"; K. | ||
# Choromanski, A. Iscen, V. Sindhwani, J. Tan, E. Coumans, ICRA 2018 | ||
# "Structured Evolution with Compact Architectures for Scalable Policy | ||
# Optimization"; K. Choromanski, M. Rowland, V. Sindhwani, R. Turner, A. | ||
# Weller, ICML 2018, https://arxiv.org/abs/1804.02395 | ||
# "From Complexity to Simplicity: Adaptive ES-Active Subspaces for Blackbox | ||
# Optimization"; K. Choromanski, A. Pacchiano, J. Parker-Holder, Y. Tang, V. | ||
# Sindhwani, NeurIPS 2019 | ||
# "i-Sim2Real: Reinforcement Learning on Robotic Policies in Tight Human-Robot | ||
# Interaction Loops"; L. Graesser, D. D'Ambrosio, A. Singh, A. Bewley, D. Jain, | ||
# K. Choromanski, P. Sanketi , CoRL 2022, https://arxiv.org/abs/2207.06572 | ||
# "Agile Catching with Whole-Body MPC and Blackbox Policy Learning"; S. | ||
# Abeyruwan, A. Bewley, N. Boffi, K. Choromanski, D. D'Ambrosio, D. Jain, P. | ||
# Sanketi, A. Shankar, V. Sindhwani, S. Singh, J. Slotine, S. Tu, L4DC, | ||
# https://arxiv.org/abs/2306.08205 | ||
# "Robotic Table Tennis: A Case Study into a High Speed Learning System"; A. | ||
# Bewley, A. Shankar, A. Iscen, A. Singh, C. Lynch, D. D'Ambrosio, D. Jain, | ||
# E. Coumans, G. Versom, G. Kouretas, J. Abelian, J. Boyd, K. Oslund, | ||
# K. Reymann, K. Choromanski, L. Graesser, M. Ahn, N. Jaitly, N. Lazic, | ||
# P. Sanketi, P. Xu, P. Sermanet, R. Mahjourian, S. Abeyruwan, S. Kataoka, | ||
# S. Moore, T. Nguyen, T. Ding, V. Sindhwani, V. Vanhoucke, W. Gao, Y. Kuang, | ||
# to be presented at RSS 2023 | ||
############################################################################### | ||
"""Util functions to create and edit a tf_agent policy.""" | ||
|
||
import gin | ||
import numpy as np | ||
import numpy.typing as npt | ||
import tensorflow as tf | ||
from typing import Union | ||
|
||
from tf_agents.networks import network | ||
from tf_agents.policies import actor_policy, greedy_policy, tf_policy | ||
from compiler_opt.rl import policy_saver, registry | ||
|
||
|
||
@gin.configurable(module='policy_utils') | ||
def create_actor_policy(actor_network_ctor: network.DistributionNetwork, | ||
greedy: bool = False) -> tf_policy.TFPolicy: | ||
"""Creates an actor policy.""" | ||
problem_config = registry.get_configuration() | ||
time_step_spec, action_spec = problem_config.get_signature_spec() | ||
layers = tf.nest.map_structure( | ||
problem_config.get_preprocessing_layer_creator(), | ||
time_step_spec.observation) | ||
|
||
actor_network = actor_network_ctor( | ||
input_tensor_spec=time_step_spec.observation, | ||
output_tensor_spec=action_spec, | ||
preprocessing_layers=layers) | ||
|
||
policy = actor_policy.ActorPolicy( | ||
time_step_spec=time_step_spec, | ||
action_spec=action_spec, | ||
actor_network=actor_network) | ||
|
||
if greedy: | ||
policy = greedy_policy.GreedyPolicy(policy) | ||
|
||
return policy | ||
|
||
|
||
def get_vectorized_parameters_from_policy( | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. doc strings please (for all of them) There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Doc strings have been added |
||
policy: Union[tf_policy.TFPolicy, tf.Module]) -> npt.NDArray[np.float32]: | ||
if isinstance(policy, tf_policy.TFPolicy): | ||
variables = policy.variables() | ||
elif policy.model_variables: | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I'd argue for |
||
variables = policy.model_variables | ||
|
||
parameters = [var.numpy().flatten() for var in variables] | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. can you have a unit test to make sure that a TFPolicy and its loaded SavedModel have identical ordering of variables? (it's sufficient to check that the float values in parameters are approximately identical using np.testing.assert_allclose or similar) There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I added a new test for this. Please check to make sure I understood correctly. Thanks |
||
parameters = np.concatenate(parameters, axis=0) | ||
return parameters | ||
|
||
|
||
def set_vectorized_parameters_for_policy( | ||
policy: Union[tf_policy.TFPolicy, | ||
tf.Module], parameters: npt.NDArray[np.float32]) -> None: | ||
if isinstance(policy, tf_policy.TFPolicy): | ||
variables = policy.variables() | ||
else: | ||
try: | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. for consistency, whatever you do here should match whatever we do on line 91. Come to think of it, I think the python preference is to raise There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The checks have been changed to be the same now--check for TFPolicy, check for model_variables, else raise ValueError |
||
getattr(policy, 'model_variables') | ||
except AttributeError as e: | ||
raise TypeError('policy must be a TFPolicy or a loaded SavedModel') from e | ||
variables = policy.model_variables | ||
|
||
param_pos = 0 | ||
for variable in variables: | ||
shape = tf.shape(variable).numpy() | ||
num_ele = np.prod(shape) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Yeah it is a bit awkward, I changed it to num_elems now |
||
param = np.reshape(parameters[param_pos:param_pos + num_ele], shape) | ||
variable.assign(param) | ||
param_pos += num_ele | ||
if param_pos != len(parameters): | ||
raise ValueError( | ||
f'Parameter dimensions are not matched! Expected {len(parameters)} ' | ||
'but only found {param_pos}.') | ||
|
||
|
||
def save_policy(policy: tf_policy.TFPolicy, parameters: npt.NDArray[np.float32], | ||
save_folder: str, policy_name: str) -> None: | ||
set_vectorized_parameters_for_policy(policy, parameters) | ||
saver = policy_saver.PolicySaver({policy_name: policy}) | ||
saver.save(save_folder) |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,210 @@ | ||
# coding=utf-8 | ||
# Copyright 2020 Google LLC | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
|
||
############################################################################### | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. same comment re. this bit of the docstring |
||
# | ||
# | ||
# This is a port of the code by Krzysztof Choromanski, Deepali Jain and Vikas | ||
# Sindhwani, based on the portfolio of Blackbox optimization algorithms listed | ||
# below: | ||
# | ||
# "On Blackbox Backpropagation and Jacobian Sensing"; K. Choromanski, | ||
# V. Sindhwani, NeurIPS 2017 | ||
# "Optimizing Simulations with Noise-Tolerant Structured Exploration"; K. | ||
# Choromanski, A. Iscen, V. Sindhwani, J. Tan, E. Coumans, ICRA 2018 | ||
# "Structured Evolution with Compact Architectures for Scalable Policy | ||
# Optimization"; K. Choromanski, M. Rowland, V. Sindhwani, R. Turner, A. | ||
# Weller, ICML 2018, https://arxiv.org/abs/1804.02395 | ||
# "From Complexity to Simplicity: Adaptive ES-Active Subspaces for Blackbox | ||
# Optimization"; K. Choromanski, A. Pacchiano, J. Parker-Holder, Y. Tang, V. | ||
# Sindhwani, NeurIPS 2019 | ||
# "i-Sim2Real: Reinforcement Learning on Robotic Policies in Tight Human-Robot | ||
# Interaction Loops"; L. Graesser, D. D'Ambrosio, A. Singh, A. Bewley, D. Jain, | ||
# K. Choromanski, P. Sanketi , CoRL 2022, https://arxiv.org/abs/2207.06572 | ||
# "Agile Catching with Whole-Body MPC and Blackbox Policy Learning"; S. | ||
# Abeyruwan, A. Bewley, N. Boffi, K. Choromanski, D. D'Ambrosio, D. Jain, P. | ||
# Sanketi, A. Shankar, V. Sindhwani, S. Singh, J. Slotine, S. Tu, L4DC, | ||
# https://arxiv.org/abs/2306.08205 | ||
# "Robotic Table Tennis: A Case Study into a High Speed Learning System"; A. | ||
# Bewley, A. Shankar, A. Iscen, A. Singh, C. Lynch, D. D'Ambrosio, D. Jain, | ||
# E. Coumans, G. Versom, G. Kouretas, J. Abelian, J. Boyd, K. Oslund, | ||
# K. Reymann, K. Choromanski, L. Graesser, M. Ahn, N. Jaitly, N. Lazic, | ||
# P. Sanketi, P. Xu, P. Sermanet, R. Mahjourian, S. Abeyruwan, S. Kataoka, | ||
# S. Moore, T. Nguyen, T. Ding, V. Sindhwani, V. Vanhoucke, W. Gao, Y. Kuang, | ||
# to be presented at RSS 2023 | ||
############################################################################### | ||
"""Tests for policy_utils.""" | ||
|
||
from absl.testing import absltest | ||
import numpy as np | ||
import os | ||
import tensorflow as tf | ||
from tf_agents.networks import actor_distribution_network | ||
from tf_agents.policies import actor_policy | ||
|
||
from compiler_opt.es import policy_utils | ||
from compiler_opt.rl import policy_saver, registry | ||
from compiler_opt.rl.inlining import InliningConfig | ||
from compiler_opt.rl.inlining import config as inlining_config | ||
from compiler_opt.rl.regalloc import config as regalloc_config | ||
from compiler_opt.rl.regalloc import RegallocEvictionConfig, regalloc_network | ||
|
||
|
||
class ConfigTest(absltest.TestCase): | ||
mtrofin marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
||
def test_inlining_config(self): | ||
problem_config = registry.get_configuration(implementation=InliningConfig) | ||
time_step_spec, action_spec = problem_config.get_signature_spec() | ||
creator = inlining_config.get_observation_processing_layer_creator( | ||
quantile_file_dir='compiler_opt/rl/inlining/vocab/', | ||
with_sqrt=False, | ||
with_z_score_normalization=False) | ||
layers = tf.nest.map_structure(creator, time_step_spec.observation) | ||
|
||
actor_network = actor_distribution_network.ActorDistributionNetwork( | ||
input_tensor_spec=time_step_spec.observation, | ||
output_tensor_spec=action_spec, | ||
preprocessing_layers=layers, | ||
preprocessing_combiner=tf.keras.layers.Concatenate(), | ||
fc_layer_params=(64, 64, 64, 64), | ||
dropout_layer_params=None, | ||
activation_fn=tf.keras.activations.relu) | ||
|
||
policy = actor_policy.ActorPolicy( | ||
time_step_spec=time_step_spec, | ||
action_spec=action_spec, | ||
actor_network=actor_network) | ||
|
||
self.assertIsNotNone(policy) | ||
self.assertIsInstance( | ||
policy._actor_network, # pylint: disable=protected-access | ||
actor_distribution_network.ActorDistributionNetwork) | ||
|
||
def test_regalloc_config(self): | ||
problem_config = registry.get_configuration( | ||
implementation=RegallocEvictionConfig) | ||
time_step_spec, action_spec = problem_config.get_signature_spec() | ||
creator = regalloc_config.get_observation_processing_layer_creator( | ||
quantile_file_dir='compiler_opt/rl/regalloc/vocab', | ||
with_sqrt=False, | ||
with_z_score_normalization=False) | ||
layers = tf.nest.map_structure(creator, time_step_spec.observation) | ||
|
||
actor_network = regalloc_network.RegAllocNetwork( | ||
input_tensor_spec=time_step_spec.observation, | ||
output_tensor_spec=action_spec, | ||
preprocessing_layers=layers, | ||
preprocessing_combiner=tf.keras.layers.Concatenate(), | ||
fc_layer_params=(64, 64, 64, 64), | ||
dropout_layer_params=None, | ||
activation_fn=tf.keras.activations.relu) | ||
|
||
policy = actor_policy.ActorPolicy( | ||
time_step_spec=time_step_spec, | ||
action_spec=action_spec, | ||
actor_network=actor_network) | ||
|
||
self.assertIsNotNone(policy) | ||
self.assertIsInstance( | ||
policy._actor_network, # pylint: disable=protected-access | ||
regalloc_network.RegAllocNetwork) | ||
|
||
|
||
class VectorTest(absltest.TestCase): | ||
|
||
def test_set_vectorized_parameters_for_policy(self): | ||
# create a policy | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 2 high level questions:
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I will have to look into other ways of creating a policy in order to allow decoupling. In regards to the tests, I have added sections to test loaded policies now. Debugging has revealed that the loaded policy is not an instance of tf.Module but rather one of AutoTrackable. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. ok - could you also add a reference to #280 over each test, easier to avoid forgetting |
||
problem_config = registry.get_configuration(implementation=InliningConfig) | ||
time_step_spec, action_spec = problem_config.get_signature_spec() | ||
creator = inlining_config.get_observation_processing_layer_creator( | ||
quantile_file_dir='compiler_opt/rl/inlining/vocab/', | ||
with_sqrt=False, | ||
with_z_score_normalization=False) | ||
layers = tf.nest.map_structure(creator, time_step_spec.observation) | ||
|
||
actor_network = actor_distribution_network.ActorDistributionNetwork( | ||
input_tensor_spec=time_step_spec.observation, | ||
output_tensor_spec=action_spec, | ||
preprocessing_layers=layers, | ||
preprocessing_combiner=tf.keras.layers.Concatenate(), | ||
fc_layer_params=(64, 64, 64, 64), | ||
dropout_layer_params=None, | ||
activation_fn=tf.keras.activations.relu) | ||
|
||
policy = actor_policy.ActorPolicy( | ||
time_step_spec=time_step_spec, | ||
action_spec=action_spec, | ||
actor_network=actor_network) | ||
saver = policy_saver.PolicySaver({'policy': policy}) | ||
|
||
# save the policy | ||
testing_path = self.create_tempdir() | ||
policy_save_path = os.path.join(testing_path, 'temp_output/policy') | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. `os.path.join(testing_path, 'temp_output', 'policy') i.e. don't assume '/' is the separator. also, can we call 'policy' something else, it's a bit confusing how then we add again a 'policy' on line 144 There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Ok, I made a variable POLICY_NAME and used it for the name in the dict on lines like 126 here for clarity. Should I also change lines with There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. That's fine, we'll remove it later bc #280 anyway. |
||
saver.save(policy_save_path) | ||
|
||
# set the values of the policy variables | ||
length_of_a_perturbation = 17218 | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. why 17218 - it's the sum of the shapes on line 129, correct? could you move that line above, then calculate length_of_a_perturbation from it, and maybe rename length_of_a... to expected_length_of_a_perturbation - then it's (I'd argue) more clear what's going on. |
||
params = np.arange(length_of_a_perturbation, dtype=np.float32) | ||
policy_utils.set_vectorized_parameters_for_policy(policy, params) | ||
# iterate through variables and check their values | ||
idx = 0 | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. same idea with idx... and same comment further below about naming. |
||
for variable in policy.variables(): # pylint: disable=not-callable | ||
nums = variable.numpy().flatten() | ||
for num in nums: | ||
if idx != num: | ||
raise AssertionError(f'values at index {idx} do not match') | ||
idx += 1 | ||
|
||
def test_get_vectorized_parameters_from_policy(self): | ||
# create a policy | ||
problem_config = registry.get_configuration(implementation=InliningConfig) | ||
time_step_spec, action_spec = problem_config.get_signature_spec() | ||
creator = inlining_config.get_observation_processing_layer_creator( | ||
quantile_file_dir='compiler_opt/rl/inlining/vocab/', | ||
with_sqrt=False, | ||
with_z_score_normalization=False) | ||
layers = tf.nest.map_structure(creator, time_step_spec.observation) | ||
|
||
actor_network = actor_distribution_network.ActorDistributionNetwork( | ||
input_tensor_spec=time_step_spec.observation, | ||
output_tensor_spec=action_spec, | ||
preprocessing_layers=layers, | ||
preprocessing_combiner=tf.keras.layers.Concatenate(), | ||
fc_layer_params=(64, 64, 64, 64), | ||
dropout_layer_params=None, | ||
activation_fn=tf.keras.activations.relu) | ||
|
||
policy = actor_policy.ActorPolicy( | ||
time_step_spec=time_step_spec, | ||
action_spec=action_spec, | ||
actor_network=actor_network) | ||
saver = policy_saver.PolicySaver({'policy': policy}) | ||
|
||
# save the policy | ||
testing_path = self.create_tempdir() | ||
policy_save_path = os.path.join(testing_path, 'temp_output/policy') | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. same comment about path and names |
||
saver.save(policy_save_path) | ||
|
||
length_of_a_perturbation = 17218 | ||
params = np.arange(length_of_a_perturbation, dtype=np.float32) | ||
# functionality verified in previous test | ||
policy_utils.set_vectorized_parameters_for_policy(policy, params) | ||
# vectorize and check if the outcome is the same as the start | ||
output = policy_utils.get_vectorized_parameters_from_policy(policy) | ||
np.testing.assert_array_almost_equal(output, params) | ||
|
||
|
||
if __name__ == '__main__': | ||
absltest.main() |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I don't think this specific file needs this - these are general - purpose TF utilities.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The extra parts have been removed now