add policy_utils #279

Merged: 8 commits, Jul 25, 2023
110 changes: 110 additions & 0 deletions compiler_opt/es/policy_utils.py
@@ -0,0 +1,110 @@
# coding=utf-8
# Copyright 2020 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Util functions to create and edit a tf_agent policy."""

import gin
import numpy as np
import numpy.typing as npt
import tensorflow as tf
from typing import Protocol, Sequence

from compiler_opt.rl import policy_saver, registry
from tf_agents.networks import network
from tf_agents.policies import actor_policy, greedy_policy, tf_policy


class HasModelVariables(Protocol):
  model_variables: Sequence[tf.Variable]
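
Since `HasModelVariables` is a `typing.Protocol`, matching is structural: any object exposing a `model_variables` attribute fits the annotation, which is how a SavedModel loaded via `tf.saved_model.load` qualifies. A minimal sketch; `FakeLoadedPolicy` is a hypothetical stand-in, not part of the PR:

# Hypothetical stand-in for a loaded SavedModel.
class FakeLoadedPolicy:

  def __init__(self):
    self.model_variables = [tf.Variable([0.0, 1.0], dtype=tf.float32)]

# The utilities below never isinstance-check the protocol; they rely on
# this duck-typed attribute check:
assert hasattr(FakeLoadedPolicy(), 'model_variables')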


# TODO(abenalaast): Issue #280
@gin.configurable(module='policy_utils')
def create_actor_policy(actor_network_ctor: network.DistributionNetwork,
                        greedy: bool = False) -> tf_policy.TFPolicy:
  """Creates an actor policy."""
  problem_config = registry.get_configuration()
  time_step_spec, action_spec = problem_config.get_signature_spec()
  layers = tf.nest.map_structure(
      problem_config.get_preprocessing_layer_creator(),
      time_step_spec.observation)

  actor_network = actor_network_ctor(
      input_tensor_spec=time_step_spec.observation,
      output_tensor_spec=action_spec,
      preprocessing_layers=layers)

  policy = actor_policy.ActorPolicy(
      time_step_spec=time_step_spec,
      action_spec=action_spec,
      actor_network=actor_network)

  if greedy:
    policy = greedy_policy.GreedyPolicy(policy)

  return policy
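
A usage sketch: because `create_actor_policy` is gin-configurable under module `policy_utils`, the network constructor can be supplied as a gin binding. The binding below is illustrative and assumes the problem configuration consumed by `registry.get_configuration` has already been set up in gin:

import gin
from tf_agents.networks import actor_distribution_network
from compiler_opt.es import policy_utils

# Illustrative binding; any DistributionNetwork constructor works here.
gin.bind_parameter('policy_utils.create_actor_policy.actor_network_ctor',
                   actor_distribution_network.ActorDistributionNetwork)
policy = policy_utils.create_actor_policy(greedy=True)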


def get_vectorized_parameters_from_policy(
    policy: 'tf_policy.TFPolicy | HasModelVariables'
) -> npt.NDArray[np.float32]:
  """Returns a policy's variable values as a single np array."""
  if isinstance(policy, tf_policy.TFPolicy):
    variables = policy.variables()
  elif hasattr(policy, 'model_variables'):
    variables = policy.model_variables
  else:
    raise ValueError(f'Policy must be a TFPolicy or a loaded SavedModel. '
                     f'Passed policy: {policy}')

  parameters = [var.numpy().flatten() for var in variables]
  parameters = np.concatenate(parameters, axis=0)
  return parameters

Review thread on the function definition:
  Collaborator: doc strings please (for all of them)
  Author: Doc strings have been added

Review thread on the parameter flattening:
  Collaborator: can you have a unit test to make sure that a TFPolicy and its loaded SavedModel have identical ordering of variables? (it's sufficient to check that the float values in parameters are approximately identical using np.testing.assert_allclose or similar)
  Author: I added a new test for this. Please check to make sure I understood correctly. Thanks
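
A toy illustration (not from the PR) of the layout `get_vectorized_parameters_from_policy` produces: variables are flattened row-major and concatenated in variable order into one float32 vector:

v1 = tf.Variable(np.arange(6, dtype=np.float32).reshape(2, 3))
v2 = tf.Variable(np.zeros(3, dtype=np.float32))

class Holder:  # hypothetical object satisfying HasModelVariables
  model_variables = [v1, v2]

vec = get_vectorized_parameters_from_policy(Holder())
assert vec.shape == (9,)  # the 6 elements of v1 followed by the 3 of v2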


def set_vectorized_parameters_for_policy(
    policy: 'tf_policy.TFPolicy | HasModelVariables',
    parameters: npt.NDArray[np.float32]) -> None:
  """Reshapes the flat `parameters` vector to match the policy's variable
  shapes and assigns the resulting values to the policy's variables."""
  if isinstance(policy, tf_policy.TFPolicy):
    variables = policy.variables()
  elif hasattr(policy, 'model_variables'):
    variables = policy.model_variables
  else:
    raise ValueError(f'Policy must be a TFPolicy or a loaded SavedModel. '
                     f'Passed policy: {policy}')

  param_pos = 0
  for variable in variables:
    shape = tf.shape(variable).numpy()
    num_elems = np.prod(shape)
    param = np.reshape(parameters[param_pos:param_pos + num_elems], shape)
    variable.assign(param)
    param_pos += num_elems
  if param_pos != len(parameters):
    raise ValueError(
        f'Parameter dimensions are not matched! Expected {len(parameters)} '
        f'but only found {param_pos}.')
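
Together, the get/set pair supports the round trip an ES-style workflow needs: read the flat parameter vector, perturb it, and write it back. A minimal sketch, assuming `policy` came from `create_actor_policy`; the noise scale is illustrative:

params = get_vectorized_parameters_from_policy(policy)
# Gaussian perturbation; 0.01 is an arbitrary illustrative scale.
noise = np.random.default_rng(0).normal(
    scale=0.01, size=params.shape).astype(np.float32)
set_vectorized_parameters_for_policy(policy, params + noise)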


def save_policy(policy: 'tf_policy.TFPolicy | HasModelVariables',
                parameters: npt.NDArray[np.float32], save_folder: str,
                policy_name: str) -> None:
  """Sets the policy's variables to `parameters`, then saves the policy
  under `save_folder` as `policy_name`."""
  set_vectorized_parameters_for_policy(policy, parameters)
  saver = policy_saver.PolicySaver({policy_name: policy})
  saver.save(save_folder)
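
A usage sketch for the save path (directory and policy names illustrative). `policy_saver.PolicySaver` writes each named policy into a subdirectory of `save_folder`, so the result reloads as an object satisfying `HasModelVariables`:

perturbed = get_vectorized_parameters_from_policy(policy)
save_policy(policy, perturbed, '/tmp/es_out', 'perturbed_policy')
sm = tf.saved_model.load('/tmp/es_out/perturbed_policy')
roundtrip = get_vectorized_parameters_from_policy(sm)  # same vector back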
264 changes: 264 additions & 0 deletions compiler_opt/es/policy_utils_test.py
@@ -0,0 +1,264 @@
# coding=utf-8
# Copyright 2020 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for policy_utils."""

from absl.testing import absltest
import numpy as np
import os
import tensorflow as tf
from tf_agents.networks import actor_distribution_network
from tf_agents.policies import actor_policy, tf_policy

from compiler_opt.es import policy_utils
from compiler_opt.rl import policy_saver, registry
from compiler_opt.rl.inlining import config as inlining_config
from compiler_opt.rl.inlining import InliningConfig
from compiler_opt.rl.regalloc import config as regalloc_config
from compiler_opt.rl.regalloc import RegallocEvictionConfig, regalloc_network


class ConfigTest(absltest.TestCase):

  # TODO(abenalaast): Issue #280
  def test_inlining_config(self):
    problem_config = registry.get_configuration(implementation=InliningConfig)
    time_step_spec, action_spec = problem_config.get_signature_spec()
    quantile_file_dir = os.path.join('compiler_opt', 'rl', 'inlining', 'vocab')
    creator = inlining_config.get_observation_processing_layer_creator(
        quantile_file_dir=quantile_file_dir,
        with_sqrt=False,
        with_z_score_normalization=False)
    layers = tf.nest.map_structure(creator, time_step_spec.observation)

    actor_network = actor_distribution_network.ActorDistributionNetwork(
        input_tensor_spec=time_step_spec.observation,
        output_tensor_spec=action_spec,
        preprocessing_layers=layers,
        preprocessing_combiner=tf.keras.layers.Concatenate(),
        fc_layer_params=(64, 64, 64, 64),
        dropout_layer_params=None,
        activation_fn=tf.keras.activations.relu)

    policy = actor_policy.ActorPolicy(
        time_step_spec=time_step_spec,
        action_spec=action_spec,
        actor_network=actor_network)

    self.assertIsNotNone(policy)
    self.assertIsInstance(
        policy._actor_network,  # pylint: disable=protected-access
        actor_distribution_network.ActorDistributionNetwork)

  # TODO(abenalaast): Issue #280
  def test_regalloc_config(self):
    problem_config = registry.get_configuration(
        implementation=RegallocEvictionConfig)
    time_step_spec, action_spec = problem_config.get_signature_spec()
    quantile_file_dir = os.path.join('compiler_opt', 'rl', 'regalloc', 'vocab')
    creator = regalloc_config.get_observation_processing_layer_creator(
        quantile_file_dir=quantile_file_dir,
        with_sqrt=False,
        with_z_score_normalization=False)
    layers = tf.nest.map_structure(creator, time_step_spec.observation)

    actor_network = regalloc_network.RegAllocNetwork(
        input_tensor_spec=time_step_spec.observation,
        output_tensor_spec=action_spec,
        preprocessing_layers=layers,
        preprocessing_combiner=tf.keras.layers.Concatenate(),
        fc_layer_params=(64, 64, 64, 64),
        dropout_layer_params=None,
        activation_fn=tf.keras.activations.relu)

    policy = actor_policy.ActorPolicy(
        time_step_spec=time_step_spec,
        action_spec=action_spec,
        actor_network=actor_network)

    self.assertIsNotNone(policy)
    self.assertIsInstance(
        policy._actor_network,  # pylint: disable=protected-access
        regalloc_network.RegAllocNetwork)


class VectorTest(absltest.TestCase):

  expected_variable_shapes = [(71, 64), (64,), (64, 64), (64,), (64, 64),
                              (64,), (64, 64), (64,), (64, 2), (2,)]
  expected_length_of_a_perturbation = sum(
      np.prod(shape) for shape in expected_variable_shapes)
  params = np.arange(expected_length_of_a_perturbation, dtype=np.float32)
  POLICY_NAME = 'test_policy_name'
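
A quick arithmetic check on the fixture, derived purely from the expected shapes above (a 71-wide input into four 64-unit layers feeding a 2-logit output):

  # 71*64 + 64 + 3*(64*64 + 64) + (64*2 + 2)
  #   = 4544 + 64 + 12480 + 130
  #   = 17218 elements in a flattened perturbation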

  # TODO(abenalaast): Issue #280
  def test_set_vectorized_parameters_for_policy(self):
    # create a policy
    problem_config = registry.get_configuration(implementation=InliningConfig)
    time_step_spec, action_spec = problem_config.get_signature_spec()
    quantile_file_dir = os.path.join('compiler_opt', 'rl', 'inlining', 'vocab')
    creator = inlining_config.get_observation_processing_layer_creator(
        quantile_file_dir=quantile_file_dir,
        with_sqrt=False,
        with_z_score_normalization=False)
    layers = tf.nest.map_structure(creator, time_step_spec.observation)

    actor_network = actor_distribution_network.ActorDistributionNetwork(
        input_tensor_spec=time_step_spec.observation,
        output_tensor_spec=action_spec,
        preprocessing_layers=layers,
        preprocessing_combiner=tf.keras.layers.Concatenate(),
        fc_layer_params=(64, 64, 64, 64),
        dropout_layer_params=None,
        activation_fn=tf.keras.activations.relu)

    policy = actor_policy.ActorPolicy(
        time_step_spec=time_step_spec,
        action_spec=action_spec,
        actor_network=actor_network)

    # save the policy
    saver = policy_saver.PolicySaver({VectorTest.POLICY_NAME: policy})
    testing_path = self.create_tempdir()
    policy_save_path = os.path.join(testing_path, 'temp_output', 'policy')
    saver.save(policy_save_path)

    # set the values of the policy variables
    policy_utils.set_vectorized_parameters_for_policy(policy,
                                                      VectorTest.params)
    # iterate through variables and check their shapes and values
    # deep copy params in order to destructively iterate over values
    expected_values = [*VectorTest.params]
    for i, variable in enumerate(policy.variables()):  # pylint: disable=not-callable
      self.assertEqual(variable.shape, VectorTest.expected_variable_shapes[i])
      variable_values = variable.numpy().flatten()
      np.testing.assert_array_almost_equal(
          expected_values[:len(variable_values)], variable_values)
      expected_values = expected_values[len(variable_values):]
    # all values in the copy should have been removed at this point
    self.assertEmpty(expected_values)

    # get saved model to test a loaded policy
    load_path = os.path.join(policy_save_path, VectorTest.POLICY_NAME)
    sm = tf.saved_model.load(load_path)
    self.assertNotIsInstance(sm, tf_policy.TFPolicy)
    policy_utils.set_vectorized_parameters_for_policy(sm, VectorTest.params)
    # deep copy params in order to destructively iterate over values
    expected_values = [*VectorTest.params]
    for i, variable in enumerate(sm.model_variables):
      self.assertEqual(variable.shape, VectorTest.expected_variable_shapes[i])
      variable_values = variable.numpy().flatten()
      np.testing.assert_array_almost_equal(
          expected_values[:len(variable_values)], variable_values)
      expected_values = expected_values[len(variable_values):]
    # all values in the copy should have been removed at this point
    self.assertEmpty(expected_values)

Review thread on the test body:
  Collaborator: 2 high level questions:
    • can we decouple these tests from registry and all that
    • can we test the 2 supported scenarios: TFAgent and tf.Module.
  Author: I will have to look into other ways of creating a policy in order to allow decoupling. In regards to the tests, I have added sections to test loaded policies now. Debugging has revealed that the loaded policy is not an instance of tf.Module but rather one of AutoTrackable.
  Collaborator: ok - could you also add a reference to #280 over each test, easier to avoid forgetting

Review thread on the `expected_values` copy:
  Collaborator: nit: add a comment that we want to destructively go over the expected values, hence the deep copy.
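
As the review exchange above records, `tf.saved_model.load` returns an AutoTrackable rather than a `tf.Module` or `TFPolicy`; an illustrative way to confirm that observation in a debugger or test:

sm = tf.saved_model.load(load_path)
print(isinstance(sm, tf_policy.TFPolicy))  # False, hence the hasattr branch
print(type(sm).__mro__)  # includes AutoTrackable, per the author's debugging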

  # TODO(abenalaast): Issue #280
  def test_get_vectorized_parameters_from_policy(self):
    # create a policy
    problem_config = registry.get_configuration(implementation=InliningConfig)
    time_step_spec, action_spec = problem_config.get_signature_spec()
    quantile_file_dir = os.path.join('compiler_opt', 'rl', 'inlining', 'vocab')
    creator = inlining_config.get_observation_processing_layer_creator(
        quantile_file_dir=quantile_file_dir,
        with_sqrt=False,
        with_z_score_normalization=False)
    layers = tf.nest.map_structure(creator, time_step_spec.observation)

    actor_network = actor_distribution_network.ActorDistributionNetwork(
        input_tensor_spec=time_step_spec.observation,
        output_tensor_spec=action_spec,
        preprocessing_layers=layers,
        preprocessing_combiner=tf.keras.layers.Concatenate(),
        fc_layer_params=(64, 64, 64, 64),
        dropout_layer_params=None,
        activation_fn=tf.keras.activations.relu)

    policy = actor_policy.ActorPolicy(
        time_step_spec=time_step_spec,
        action_spec=action_spec,
        actor_network=actor_network)

    # save the policy
    saver = policy_saver.PolicySaver({VectorTest.POLICY_NAME: policy})
    testing_path = self.create_tempdir()
    policy_save_path = os.path.join(testing_path, 'temp_output', 'policy')
    saver.save(policy_save_path)

    # functionality verified in previous test
    policy_utils.set_vectorized_parameters_for_policy(policy,
                                                      VectorTest.params)
    # vectorize and check if the outcome is the same as the start
    output = policy_utils.get_vectorized_parameters_from_policy(policy)
    np.testing.assert_array_almost_equal(output, VectorTest.params)

    # get saved model to test a loaded policy
    load_path = os.path.join(policy_save_path, VectorTest.POLICY_NAME)
    sm = tf.saved_model.load(load_path)
    self.assertNotIsInstance(sm, tf_policy.TFPolicy)
    policy_utils.set_vectorized_parameters_for_policy(sm, VectorTest.params)
    # vectorize and check if the outcome is the same as the start
    output = policy_utils.get_vectorized_parameters_from_policy(sm)
    np.testing.assert_array_almost_equal(output, VectorTest.params)

  # TODO(abenalaast): Issue #280
  def test_tfpolicy_and_loaded_policy_produce_same_variable_order(self):
    # create a policy
    problem_config = registry.get_configuration(implementation=InliningConfig)
    time_step_spec, action_spec = problem_config.get_signature_spec()
    quantile_file_dir = os.path.join('compiler_opt', 'rl', 'inlining', 'vocab')
    creator = inlining_config.get_observation_processing_layer_creator(
        quantile_file_dir=quantile_file_dir,
        with_sqrt=False,
        with_z_score_normalization=False)
    layers = tf.nest.map_structure(creator, time_step_spec.observation)

    actor_network = actor_distribution_network.ActorDistributionNetwork(
        input_tensor_spec=time_step_spec.observation,
        output_tensor_spec=action_spec,
        preprocessing_layers=layers,
        preprocessing_combiner=tf.keras.layers.Concatenate(),
        fc_layer_params=(64, 64, 64, 64),
        dropout_layer_params=None,
        activation_fn=tf.keras.activations.relu)

    policy = actor_policy.ActorPolicy(
        time_step_spec=time_step_spec,
        action_spec=action_spec,
        actor_network=actor_network)

    # save the policy
    saver = policy_saver.PolicySaver({VectorTest.POLICY_NAME: policy})
    testing_path = self.create_tempdir()
    policy_save_path = os.path.join(testing_path, 'temp_output', 'policy')
    saver.save(policy_save_path)

    # set the values of the variables
    policy_utils.set_vectorized_parameters_for_policy(policy,
                                                      VectorTest.params)
    # save the changes
    saver.save(policy_save_path)
    # vectorize the tfpolicy
    tf_params = policy_utils.get_vectorized_parameters_from_policy(policy)

    # get loaded policy
    load_path = os.path.join(policy_save_path, VectorTest.POLICY_NAME)
    sm = tf.saved_model.load(load_path)
    # vectorize the loaded policy
    loaded_params = policy_utils.get_vectorized_parameters_from_policy(sm)

    # assert that they result in the same order of values
    np.testing.assert_array_almost_equal(tf_params, loaded_params)


if __name__ == '__main__':
absltest.main()