Commit: 2 changed files with 244 additions and 0 deletions.
compiler_opt/es/policy_utils.py
@@ -0,0 +1,81 @@
"""Util function to create a tf_agent policy.""" | ||
|
||
import gin | ||
import numpy as np | ||
import numpy.typing as npt | ||
import tensorflow as tf | ||
from typing import Union | ||
|
||
from tf_agents.networks import network | ||
from tf_agents.policies import actor_policy, greedy_policy, tf_policy | ||
from compiler_opt.rl import policy_saver, registry | ||
|
||
|
||
@gin.configurable(module='policy_utils')
def create_actor_policy(
    actor_network_ctor: 'type[network.DistributionNetwork]',
    greedy: bool = False) -> tf_policy.TFPolicy:
  """Creates an actor policy, optionally wrapped to act greedily."""
  problem_config = registry.get_configuration()
  time_step_spec, action_spec = problem_config.get_signature_spec()
  # One preprocessing layer per observation in the time step spec.
  layers = tf.nest.map_structure(
      problem_config.get_preprocessing_layer_creator(),
      time_step_spec.observation)

  actor_network = actor_network_ctor(
      input_tensor_spec=time_step_spec.observation,
      output_tensor_spec=action_spec,
      preprocessing_layers=layers)

  policy = actor_policy.ActorPolicy(
      time_step_spec=time_step_spec,
      action_spec=action_spec,
      actor_network=actor_network)

  if greedy:
    policy = greedy_policy.GreedyPolicy(policy)

  return policy
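
# Usage sketch (the gin binding below is illustrative, not part of this
# commit; it assumes ActorDistributionNetwork is gin-registered):
#
#   policy_utils.create_actor_policy.actor_network_ctor = @ActorDistributionNetwork
#
# after which a greedy policy can be obtained with:
#
#   policy = create_actor_policy(greedy=True)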


def get_vectorized_parameters_from_policy(
    policy: Union[tf_policy.TFPolicy, tf.Module]) -> npt.NDArray[np.float32]:
  """Flattens all policy variables into a single 1-D float32 vector."""
  if isinstance(policy, tf_policy.TFPolicy):
    variables = policy.variables()
  elif hasattr(policy, 'model_variables'):
    variables = policy.model_variables
  else:
    raise TypeError('policy must be a TFPolicy or a loaded SavedModel')

  parameters = [var.numpy().flatten() for var in variables]
  return np.concatenate(parameters, axis=0)
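
# Layout sketch (pure numpy, illustrative only): variables are flattened in
# iteration order and concatenated, so e.g. a (2, 3) kernel followed by a
# (4,) bias yields a vector of length 10:
#
#   a = np.zeros((2, 3), dtype=np.float32)
#   b = np.zeros((4,), dtype=np.float32)
#   np.concatenate([a.flatten(), b.flatten()]).shape  # (10,)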


def set_vectorized_parameters_for_policy(
    policy: Union[tf_policy.TFPolicy, tf.Module],
    parameters: npt.NDArray[np.float32]) -> None:
  """Assigns a flat parameter vector back onto the policy's variables."""
  if isinstance(policy, tf_policy.TFPolicy):
    variables = policy.variables()
  elif hasattr(policy, 'model_variables'):
    variables = policy.model_variables
  else:
    raise TypeError('policy must be a TFPolicy or a loaded SavedModel')

  param_pos = 0
  for variable in variables:
    shape = tf.shape(variable).numpy()
    num_elems = np.prod(shape)
    param = np.reshape(parameters[param_pos:param_pos + num_elems], shape)
    variable.assign(param)
    param_pos += num_elems
  if param_pos != len(parameters):
    raise ValueError(
        f'Parameter dimensions do not match: expected {len(parameters)} '
        f'but consumed {param_pos}.')
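
# Round-trip sketch: because both functions above traverse the same
# `variables` sequence in the same order, setting a vector and reading it
# back should return identical values:
#
#   set_vectorized_parameters_for_policy(policy, params)
#   assert np.array_equal(
#       get_vectorized_parameters_from_policy(policy), params)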


def save_policy(policy: tf_policy.TFPolicy,
                parameters: npt.NDArray[np.float32], save_folder: str,
                policy_name: str) -> None:
  """Assigns `parameters` to `policy`, then saves it under `save_folder`."""
  set_vectorized_parameters_for_policy(policy, parameters)
  saver = policy_saver.PolicySaver({policy_name: policy})
  saver.save(save_folder)
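
# Usage sketch (the path and name are illustrative): writes the updated
# policy as a SavedModel keyed by `policy_name` under `save_folder`:
#
#   save_policy(policy, params, '/tmp/policies', 'candidate_policy')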
compiler_opt/es/policy_utils_test.py
@@ -0,0 +1,163 @@
"""Tests for policy_utils.""" | ||
|
||
from absl.testing import absltest | ||
import numpy as np | ||
import os | ||
import tensorflow as tf | ||
from tf_agents.networks import actor_distribution_network | ||
from tf_agents.policies import actor_policy | ||
|
||
from compiler_opt.es import policy_utils | ||
from compiler_opt.rl import policy_saver, registry | ||
from compiler_opt.rl.inlining import InliningConfig | ||
from compiler_opt.rl.inlining import config as inlining_config | ||
from compiler_opt.rl.regalloc import config as regalloc_config | ||
from compiler_opt.rl.regalloc import RegallocEvictionConfig, regalloc_network | ||
|
||
|
||
class ConfigTest(absltest.TestCase):

  def test_inlining_config(self):
    problem_config = registry.get_configuration(implementation=InliningConfig)
    time_step_spec, action_spec = problem_config.get_signature_spec()
    creator = inlining_config.get_observation_processing_layer_creator(
        quantile_file_dir='compiler_opt/rl/inlining/vocab/',
        with_sqrt=False,
        with_z_score_normalization=False)
    layers = tf.nest.map_structure(creator, time_step_spec.observation)

    actor_network = actor_distribution_network.ActorDistributionNetwork(
        input_tensor_spec=time_step_spec.observation,
        output_tensor_spec=action_spec,
        preprocessing_layers=layers,
        preprocessing_combiner=tf.keras.layers.Concatenate(),
        fc_layer_params=(64, 64, 64, 64),
        dropout_layer_params=None,
        activation_fn=tf.keras.activations.relu)

    policy = actor_policy.ActorPolicy(
        time_step_spec=time_step_spec,
        action_spec=action_spec,
        actor_network=actor_network)

    self.assertIsNotNone(policy)
    self.assertIsInstance(
        policy._actor_network,  # pylint: disable=protected-access
        actor_distribution_network.ActorDistributionNetwork)

  def test_regalloc_config(self):
    problem_config = registry.get_configuration(
        implementation=RegallocEvictionConfig)
    time_step_spec, action_spec = problem_config.get_signature_spec()
    creator = regalloc_config.get_observation_processing_layer_creator(
        quantile_file_dir='compiler_opt/rl/regalloc/vocab',
        with_sqrt=False,
        with_z_score_normalization=False)
    layers = tf.nest.map_structure(creator, time_step_spec.observation)

    actor_network = regalloc_network.RegAllocNetwork(
        input_tensor_spec=time_step_spec.observation,
        output_tensor_spec=action_spec,
        preprocessing_layers=layers,
        preprocessing_combiner=tf.keras.layers.Concatenate(),
        fc_layer_params=(64, 64, 64, 64),
        dropout_layer_params=None,
        activation_fn=tf.keras.activations.relu)

    policy = actor_policy.ActorPolicy(
        time_step_spec=time_step_spec,
        action_spec=action_spec,
        actor_network=actor_network)

    self.assertIsNotNone(policy)
    self.assertIsInstance(
        policy._actor_network,  # pylint: disable=protected-access
        regalloc_network.RegAllocNetwork)


class VectorTest(absltest.TestCase):

  def test_set_vectorized_parameters_for_policy(self):
    # Create a policy.
    problem_config = registry.get_configuration(implementation=InliningConfig)
    time_step_spec, action_spec = problem_config.get_signature_spec()
    creator = inlining_config.get_observation_processing_layer_creator(
        quantile_file_dir='compiler_opt/rl/inlining/vocab/',
        with_sqrt=False,
        with_z_score_normalization=False)
    layers = tf.nest.map_structure(creator, time_step_spec.observation)

    actor_network = actor_distribution_network.ActorDistributionNetwork(
        input_tensor_spec=time_step_spec.observation,
        output_tensor_spec=action_spec,
        preprocessing_layers=layers,
        preprocessing_combiner=tf.keras.layers.Concatenate(),
        fc_layer_params=(64, 64, 64, 64),
        dropout_layer_params=None,
        activation_fn=tf.keras.activations.relu)

    policy = actor_policy.ActorPolicy(
        time_step_spec=time_step_spec,
        action_spec=action_spec,
        actor_network=actor_network)
    saver = policy_saver.PolicySaver({'policy': policy})

    # Save the policy.
    testing_path = self.create_tempdir()
    policy_save_path = os.path.join(testing_path, 'temp_output/policy')
    saver.save(policy_save_path)

    # Set the values of the policy variables. The length must equal the
    # total number of parameters in the policy constructed above.
    length_of_a_perturbation = 17218
    params = np.arange(length_of_a_perturbation, dtype=np.float32)
    policy_utils.set_vectorized_parameters_for_policy(policy, params)
    # Iterate through the variables and check that their values match the
    # assigned arange sequence.
    idx = 0
    for variable in policy.variables():  # pylint: disable=not-callable
      nums = variable.numpy().flatten()
      for num in nums:
        if idx != num:
          raise AssertionError(f'values at index {idx} do not match')
        idx += 1

  def test_get_vectorized_parameters_from_policy(self):
    # Create a policy.
    problem_config = registry.get_configuration(implementation=InliningConfig)
    time_step_spec, action_spec = problem_config.get_signature_spec()
    creator = inlining_config.get_observation_processing_layer_creator(
        quantile_file_dir='compiler_opt/rl/inlining/vocab/',
        with_sqrt=False,
        with_z_score_normalization=False)
    layers = tf.nest.map_structure(creator, time_step_spec.observation)

    actor_network = actor_distribution_network.ActorDistributionNetwork(
        input_tensor_spec=time_step_spec.observation,
        output_tensor_spec=action_spec,
        preprocessing_layers=layers,
        preprocessing_combiner=tf.keras.layers.Concatenate(),
        fc_layer_params=(64, 64, 64, 64),
        dropout_layer_params=None,
        activation_fn=tf.keras.activations.relu)

    policy = actor_policy.ActorPolicy(
        time_step_spec=time_step_spec,
        action_spec=action_spec,
        actor_network=actor_network)
    saver = policy_saver.PolicySaver({'policy': policy})

    # Save the policy.
    testing_path = self.create_tempdir()
    policy_save_path = os.path.join(testing_path, 'temp_output/policy')
    saver.save(policy_save_path)

    length_of_a_perturbation = 17218
    params = np.arange(length_of_a_perturbation, dtype=np.float32)
    # Setter functionality is verified in the previous test.
    policy_utils.set_vectorized_parameters_for_policy(policy, params)
    # Vectorize and check that the outcome matches the start.
    output = policy_utils.get_vectorized_parameters_from_policy(policy)
    np.testing.assert_array_almost_equal(output, params)


if __name__ == '__main__':
  absltest.main()
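
# These are plain absltest tests; from the repository root they can be run
# directly (the path is assumed from the imports above):
#
#   python compiler_opt/es/policy_utils_test.py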