-
Notifications
You must be signed in to change notification settings - Fork 93
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
add policy_utils #279
add policy_utils #279
Conversation
compiler_opt/es/policy_utils.py
Outdated
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
|
||
############################################################################### |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I don't think this specific file needs this - these are general - purpose TF utilities.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The extra parts have been removed now
return policy | ||
|
||
|
||
def get_vectorized_parameters_from_policy( |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
doc strings please (for all of them)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Doc strings have been added
compiler_opt/es/policy_utils.py
Outdated
policy: Union[tf_policy.TFPolicy, tf.Module]) -> npt.NDArray[np.float32]: | ||
if isinstance(policy, tf_policy.TFPolicy): | ||
variables = policy.variables() | ||
elif policy.model_variables: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I'd argue for else:
and assert the policy has a model_variables
. IIUC it's a bug otherwise (API user error: they either pass in a TFPolicy of a Module)
compiler_opt/es/policy_utils.py
Outdated
if isinstance(policy, tf_policy.TFPolicy): | ||
variables = policy.variables() | ||
else: | ||
try: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
for consistency, whatever you do here should match whatever we do on line 91. Come to think of it, I think the python preference is to raise ValueError
(i.e. not assert
- that's my C++ speaking)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The checks have been changed to be the same now--check for TFPolicy, check for model_variables, else raise ValueError
compiler_opt/es/policy_utils.py
Outdated
param_pos = 0 | ||
for variable in variables: | ||
shape = tf.shape(variable).numpy() | ||
num_ele = np.prod(shape) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
num_elems
?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yeah it is a bit awkward, I changed it to num_elems now
compiler_opt/es/policy_utils_test.py
Outdated
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
|
||
############################################################################### |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
same comment re. this bit of the docstring
class VectorTest(absltest.TestCase): | ||
|
||
def test_set_vectorized_parameters_for_policy(self): | ||
# create a policy |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
2 high level questions:
- can we decouple these tests from registry and all that
- can we test the 2 supported scenarios: TFAgent and tf.Module.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I will have to look into other ways of creating a policy in order to allow decoupling. In regards to the tests, I have added sections to test loaded policies now. Debugging has revealed that the loaded policy is not an instance of tf.Module but rather one of AutoTrackable.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
ok - could you also add a reference to #280 over each test, easier to avoid forgetting
compiler_opt/es/policy_utils.py
Outdated
elif hasattr(policy, 'model_variables'): | ||
variables = policy.model_variables | ||
else: | ||
raise ValueError('policy must be a TFPolicy or a loaded SavedModel') |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Include the policy object in the ValueError message so we know what was passed.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I updated the message now
else: | ||
raise ValueError('policy must be a TFPolicy or a loaded SavedModel') | ||
|
||
parameters = [var.numpy().flatten() for var in variables] |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
can you have a unit test to make sure that a TFPolicy and its loaded SavedModel have identical ordering of variables? (it's sufficient to check that the float values in parameters are approximately identical using np.testing.assert_allclose or similar)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I added a new test for this. Please check to make sure I understood correctly. Thanks
… edit type annotations, remove credit message
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
some interim comments - I know you were going to look at further decoupling the "value" tests from specific problem solvers ("registry"), but they may be applicable.
compiler_opt/es/policy_utils_test.py
Outdated
saver.save(policy_save_path) | ||
|
||
# set the values of the policy variables | ||
length_of_a_perturbation = 17218 |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
why 17218 - it's the sum of the shapes on line 129, correct? could you move that line above, then calculate length_of_a_perturbation from it, and maybe rename length_of_a... to expected_length_of_a_perturbation - then it's (I'd argue) more clear what's going on.
compiler_opt/es/policy_utils_test.py
Outdated
idx = 0 | ||
for i, variable in enumerate(policy.variables()): # pylint: disable=not-callable | ||
self.assertEqual(variable.shape, expected_variable_shapes[i]) | ||
nums = variable.numpy().flatten() |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
nit: s/nums/variable_values
compiler_opt/es/policy_utils_test.py
Outdated
for i, variable in enumerate(policy.variables()): # pylint: disable=not-callable | ||
self.assertEqual(variable.shape, expected_variable_shapes[i]) | ||
nums = variable.numpy().flatten() | ||
for num in nums: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
nit: s/num/variable_value
compiler_opt/es/policy_utils_test.py
Outdated
expected_variable_shapes = [(71, 64), (64), (64, 64), (64), (64, 64), (64), | ||
(64, 64), (64), (64, 2), (2)] | ||
# iterate through variables and check their shapes and values | ||
idx = 0 |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
you could say expected_values = range(expected_length_of_a_perturbation), then you don't need idx, you can just check on line 136 something like:
self.assertListEqual(expected_values[:len(variable_values)], variable_values)
expected_values = expected_values[len(variable_values:]
then at the end expected_values should be empty.
compiler_opt/es/policy_utils_test.py
Outdated
sm = tf.saved_model.load(policy_save_path + '/policy') | ||
self.assertNotIsInstance(sm, tf_policy.TFPolicy) | ||
policy_utils.set_vectorized_parameters_for_policy(sm, params) | ||
idx = 0 |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
same idea with idx... and same comment further below about naming.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
lgtm, some comments before landing
class VectorTest(absltest.TestCase): | ||
|
||
def test_set_vectorized_parameters_for_policy(self): | ||
# create a policy |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
ok - could you also add a reference to #280 over each test, easier to avoid forgetting
# set the values of the policy variables | ||
policy_utils.set_vectorized_parameters_for_policy(policy, VectorTest.params) | ||
# iterate through variables and check their shapes and values | ||
expected_values = [*VectorTest.params] |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
nit: add a comment that we want to destructively go over the expected values, hence the deep copy.
compiler_opt/es/policy_utils_test.py
Outdated
# save the policy | ||
saver = policy_saver.PolicySaver({'policy': policy}) | ||
testing_path = self.create_tempdir() | ||
policy_save_path = os.path.join(testing_path, 'temp_output/policy') |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
`os.path.join(testing_path, 'temp_output', 'policy')
i.e. don't assume '/' is the separator.
also, can we call 'policy' something else, it's a bit confusing how then we add again a 'policy' on line 144
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Ok, I made a variable POLICY_NAME and used it for the name in the dict on lines like 126 here for clarity. Should I also change lines with quantile_file_dir='compiler_opt/rl/inlining/vocab/'
to use join since the separator is hardcoded?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
That's fine, we'll remove it later bc #280 anyway.
compiler_opt/es/policy_utils_test.py
Outdated
self.assertEmpty(expected_values) | ||
|
||
# get saved model to test a loaded policy | ||
sm = tf.saved_model.load(policy_save_path + '/policy') |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
os.path.join instead of +
compiler_opt/es/policy_utils_test.py
Outdated
# save the policy | ||
saver = policy_saver.PolicySaver({'policy': policy}) | ||
testing_path = self.create_tempdir() | ||
policy_save_path = os.path.join(testing_path, 'temp_output/policy') |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
same comment about path and names
compiler_opt/es/policy_utils_test.py
Outdated
np.testing.assert_array_almost_equal(output, VectorTest.params) | ||
|
||
# get saved model to test a loaded policy | ||
sm = tf.saved_model.load(policy_save_path + '/policy') |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
same comment as before
compiler_opt/es/policy_utils_test.py
Outdated
# save the policy | ||
saver = policy_saver.PolicySaver({'policy': policy}) | ||
testing_path = self.create_tempdir() | ||
policy_save_path = os.path.join(testing_path, 'temp_output/policy') |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
here too
compiler_opt/es/policy_utils_test.py
Outdated
tf_params = policy_utils.get_vectorized_parameters_from_policy(policy) | ||
|
||
# get loaded policy | ||
sm = tf.saved_model.load(policy_save_path + '/policy') |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
here too
No description provided.