From 25b4e6ba47bbc8a6680c0581d1bbd8ae0fbbf1e3 Mon Sep 17 00:00:00 2001 From: "Teodor V. Marinov" Date: Mon, 14 Oct 2024 18:42:15 +0000 Subject: [PATCH 1/6] Initial ExplorationWorker commit. ExplorationWorker implements the interactive clang compilation together with exploration for a given module. This commit adds the class with its compile_module function which will compile the module with a given compiler policy. --- compiler_opt/rl/generate_bc_trajectories.py | 239 +++++++++++++++++- .../rl/generate_bc_trajectories_test.py | 193 ++++++++++++++ 2 files changed, 431 insertions(+), 1 deletion(-) diff --git a/compiler_opt/rl/generate_bc_trajectories.py b/compiler_opt/rl/generate_bc_trajectories.py index 79d65f2d..977ae15c 100644 --- a/compiler_opt/rl/generate_bc_trajectories.py +++ b/compiler_opt/rl/generate_bc_trajectories.py @@ -14,12 +14,101 @@ # limitations under the License. """Module for running compilation and collect data for behavior cloning.""" -from typing import Callable, Dict, List, Optional +from typing import Any, Callable, Dict, List, Optional, Tuple, Type + +from absl import logging +import os +import shutil import numpy as np import tensorflow as tf from tf_agents.trajectories import policy_step from tf_agents.trajectories import time_step +from tf_agents.specs import tensor_spec + +from compiler_opt.distributed import worker +from compiler_opt.rl import corpus +from compiler_opt.rl import env + + +def add_int_feature( + sequence_example: tf.train.SequenceExample, + feature_value: np.int64, + feature_name: str, +): + """Add an int feature to feature list. + + Args: + sequence_example: sequence example to use instead of the one belonging to + the instance + feature_value: tf.int64 value of feature + feature_name: name of feature + """ + f = sequence_example.feature_lists.feature_list[feature_name].feature.add() + lst = f.int64_list.value + lst.extend([feature_value]) + + +def add_float_feature( + sequence_example: tf.train.SequenceExample, + feature_value: np.float32, + feature_name: str, +): + """Add a float feature to feature list. + + Args: + sequence_example: sequence example to use instead of the one belonging to + the instance + feature_value: tf.int64 value of feature + feature_name: name of feature + """ + f = sequence_example.feature_lists.feature_list[feature_name].feature.add() + lst = f.float_list.value + lst.extend([feature_value]) + + +def add_string_feature( + sequence_example: tf.train.SequenceExample, + feature_value: str, + feature_name: str, +): + """Add a string feature to feature list. + + Args: + sequence_example: sequence example to use instead of the one + feature_value: tf.string value of feature + feature_name: name of feature + """ + f = sequence_example.feature_lists.feature_list[feature_name].feature.add() + lst = f.bytes_list.value + lst.extend([feature_value.encode('utf-8')]) + + +def add_feature_list(seq_example: tf.train.SequenceExample, + feature_list: List[Any], feature_name: str): + """Add the feature_list to the sequence example under feature name. + + Args: + seq_example: sequence example to update + feature_list: list of feature values to add to seq_example + feature_name: name of the feature to add the list under + """ + if (type(feature_list[0]) not in [ + np.dtype(np.int64), + np.dtype(np.float32), + str, + ]): + raise AssertionError(f'''Unsupported type for feautre {0} of type {1}. 
+ Supported types are np.int64, np.float32, str'''.format( + feature_name, type(feature_list[0]))) + if isinstance(feature_list[0], np.float32): + add_function = add_float_feature + elif isinstance(feature_list[0], (int, np.int64)): + add_function = add_int_feature + else: + add_function = add_string_feature + for feature in feature_list: + add_function(seq_example, feature, feature_name) class ExplorationWithPolicy: @@ -102,3 +191,151 @@ def get_advice(self, state: time_step.TimeStep) -> np.ndarray: break self.curr_step += 1 return policy_action + + +class ExplorationWorker(worker.Worker): + """Class which implements the exploration for the given module. + + Attributes: + loaded_module_spec: the module to be compiled and explored + use_greedy: indicates if the default/greedy policy is used to compile the + module + env: MLGO environment. + exploration_frac: how often to explore in a trajectory + max_exploration_steps: maximum number of exploration steps + exploration_policy_distr: distribution function from exploration policy. + reward_key: which reward binary to use, must be specified as part of + additional task args (kwargs). + """ + + def __init__( + self, + loaded_module_spec: corpus.LoadedModuleSpec, + clang_path: str, + mlgo_task: Type[env.MLGOTask], + use_greedy: bool = False, + exploration_frac: float = 1.0, + max_exploration_steps: int = 10, + exploration_policy_distr=None, + obs_action_specs: Optional[Tuple[time_step.TimeStep, + tensor_spec.BoundedTensorSpec,]] = None, + **kwargs, + ): + self.loaded_module_spec = loaded_module_spec + self.use_greedy = use_greedy + if not obs_action_specs: + obs_spec = None + action_spec = None + else: + obs_spec = obs_action_specs[0].observation + action_spec = obs_action_specs[1] + + if 'reward_key' not in kwargs: + raise KeyError( + 'reward_key not specified in ExplorationWorker initialization.') + self.reward_key = kwargs['reward_key'] + kwargs.pop('reward_key', None) + self.working_dir = None + + self.env = env.MLGOEnvironmentBase( + clang_path=clang_path, + task_type=mlgo_task, + obs_spec=obs_spec, + action_spec=action_spec, + ) + if self.env.action_spec: + if self.env.action_spec.dtype != tf.int64: + raise TypeError( + f'Environment action_spec type {0} does not match tf.int64'.format( + self.env.action_spec.dtype)) + self.exploration_frac = exploration_frac + self.max_exploration_steps = max_exploration_steps + self.exploration_policy_distr = exploration_policy_distr + logging.info('Reward key in exploration worker: %s', self.reward_key) + + def compile_module( + self, + policy: Callable[[Optional[time_step.TimeStep]], np.ndarray], + ) -> tf.train.SequenceExample: + """Compiles the module with the given policy and outputs a seq. example. + + Args: + policy: policy to compile with + + Returns: + sequence_example: a tf.train.SequenceExample containing the trajectory + from compilation. 
+ """ + sequence_example = tf.train.SequenceExample() + curr_obs_dict = self.env.reset(self.loaded_module_spec) + try: + curr_obs = curr_obs_dict.obs + self._process_obs(curr_obs, sequence_example) + while curr_obs_dict.step_type != env.StepType.LAST: + timestep = self._create_timestep(curr_obs_dict) + action = policy(timestep) + add_int_feature(sequence_example, int(action), 'action') + curr_obs_dict = self.env.step(action) + curr_obs = curr_obs_dict.obs + if curr_obs_dict.step_type == env.StepType.LAST: + break + self._process_obs(curr_obs, sequence_example) + except AssertionError as e: + logging.error('AssertionError: %s', e) + horizon = len(sequence_example.feature_lists.feature_list['action'].feature) + self.working_dir = curr_obs_dict.working_dir + if horizon <= 0: + working_dir_head = os.path.split(self.working_dir)[0] + shutil.rmtree(working_dir_head) + if horizon <= 0: + raise ValueError( + f'Policy did not take any inlining decision for module {0}.'.format( + self.loaded_module_spec.name)) + if curr_obs_dict.step_type != env.StepType.LAST: + raise ValueError( + f'Compilation loop exited at step type {0} before last step'.format( + curr_obs_dict.step_type)) + native_size = curr_obs_dict.score_policy[self.reward_key] + native_size_list = np.float32(native_size) * np.float32(np.ones(horizon)) + add_feature_list(sequence_example, native_size_list, 'reward') + module_name_list = [self.loaded_module_spec.name for _ in range(horizon)] + add_feature_list(sequence_example, module_name_list, 'module_name') + return sequence_example + + def _create_timestep(self, curr_obs_dict: env.TimeStep): + curr_obs = curr_obs_dict.obs + curr_obs_step = curr_obs_dict.step_type + step_type_converter = { + env.StepType.FIRST: 0, + env.StepType.MID: 1, + env.StepType.LAST: 2, + } + if curr_obs_dict.step_type == env.StepType.LAST: + reward = np.array(curr_obs_dict.score_policy[self.reward_key]) + else: + reward = np.array(0.) 
+ curr_obs_step = step_type_converter[curr_obs_step] + timestep = time_step.TimeStep( + step_type=tf.convert_to_tensor([curr_obs_step], + dtype=tf.int32, + name='step_type'), + reward=tf.convert_to_tensor([reward], dtype=tf.float32, name='reward'), + discount=tf.convert_to_tensor([0.0], dtype=tf.float32, name='discount'), + observation=curr_obs, + ) + return timestep + + def _process_obs(self, curr_obs, sequence_example): + for curr_obs_feature_name in curr_obs: + if not self.env.obs_spec: + obs_dtype = tf.int64 + else: + if curr_obs_feature_name not in self.env.obs_spec.keys(): + raise AssertionError(f'Feature name {0} not in obs_spec {1}'.format( + curr_obs_feature_name, self.env.obs_spec.keys())) + obs_dtype = self.env.obs_spec[curr_obs_feature_name].dtype + curr_obs_feature = curr_obs[curr_obs_feature_name] + curr_obs[curr_obs_feature_name] = tf.convert_to_tensor( + curr_obs_feature, dtype=obs_dtype, name=curr_obs_feature_name) + add_feature_list(sequence_example, curr_obs_feature, + curr_obs_feature_name) diff --git a/compiler_opt/rl/generate_bc_trajectories_test.py b/compiler_opt/rl/generate_bc_trajectories_test.py index a4ae9b6c..5a277b58 100644 --- a/compiler_opt/rl/generate_bc_trajectories_test.py +++ b/compiler_opt/rl/generate_bc_trajectories_test.py @@ -15,6 +15,7 @@ """Tests for compiler_opt.rl.generate_bc_trajectories.""" from typing import List +from unittest import mock import numpy as np import tensorflow as tf @@ -22,7 +23,11 @@ from tf_agents.trajectories import policy_step from tf_agents.trajectories import time_step +from google.protobuf import text_format + from compiler_opt.rl import generate_bc_trajectories +from compiler_opt.rl import env +from compiler_opt.rl import env_test _eps = 1e-5 @@ -152,3 +157,191 @@ def explore_on_feature_2_val(feature_val): for state in _get_state_list(): _ = explore_with_policy.get_advice(state)[0] self.assertEqual(1, explore_with_policy.explore_step) + + +class AddToFeatureListsTest(tf.test.TestCase): + + def test_add_int_feature(self): + sequence_example_text = """ + feature_lists { + feature_list { + key: "feature_0" + value { + feature { int64_list { value: 1 } } + feature { int64_list { value: 2 } } + } + } + feature_list { + key: "feature_1" + value { + feature { int64_list { value: 3 } } + } + } + }""" + sequence_example_comp = text_format.Parse(sequence_example_text, + tf.train.SequenceExample()) + + sequence_example = tf.train.SequenceExample() + generate_bc_trajectories.add_int_feature( + sequence_example=sequence_example, + feature_value=1, + feature_name='feature_0') + generate_bc_trajectories.add_int_feature( + sequence_example=sequence_example, + feature_value=2, + feature_name='feature_0') + generate_bc_trajectories.add_int_feature( + sequence_example=sequence_example, + feature_value=3, + feature_name='feature_1') + + self.assertEqual(sequence_example, sequence_example_comp) + + def test_add_float_feature(self): + sequence_example_text = """ + feature_lists { + feature_list { + key: "feature_0" + value { + feature { float_list { value: .1 } } + feature { float_list { value: .2 } } + } + } + feature_list { + key: "feature_1" + value { + feature { float_list { value: .3 } } + } + } + }""" + sequence_example_comp = text_format.Parse(sequence_example_text, + tf.train.SequenceExample()) + + sequence_example = tf.train.SequenceExample() + generate_bc_trajectories.add_float_feature( + sequence_example=sequence_example, + feature_value=.1, + feature_name='feature_0') + generate_bc_trajectories.add_float_feature( + 
sequence_example=sequence_example, + feature_value=.2, + feature_name='feature_0') + generate_bc_trajectories.add_float_feature( + sequence_example=sequence_example, + feature_value=.3, + feature_name='feature_1') + + self.assertEqual(sequence_example, sequence_example_comp) + + def test_add_string_feature(self): + sequence_example_text = """ + feature_lists { + feature_list { + key: "feature_0" + value { + feature { bytes_list { value: "1" } } + feature { bytes_list { value: "2" } } + } + } + feature_list { + key: "feature_1" + value { + feature { bytes_list { value: "3" } } + } + } + }""" + sequence_example_comp = text_format.Parse(sequence_example_text, + tf.train.SequenceExample()) + + sequence_example = tf.train.SequenceExample() + generate_bc_trajectories.add_string_feature( + sequence_example=sequence_example, + feature_value='1', + feature_name='feature_0') + generate_bc_trajectories.add_string_feature( + sequence_example=sequence_example, + feature_value='2', + feature_name='feature_0') + generate_bc_trajectories.add_string_feature( + sequence_example=sequence_example, + feature_value='3', + feature_name='feature_1') + + self.assertEqual(sequence_example, sequence_example_comp) + + +class ExplorationWorkerTest(tf.test.TestCase): + # pylint: disable=protected-access + @mock.patch('subprocess.Popen') + def test_create_timestep(self, mock_popen): + mock_popen.side_effect = env_test.mock_interactive_clang + + def create_timestep_comp(step_type, reward, obs): + timestep_comp = time_step.TimeStep( + step_type=tf.convert_to_tensor([step_type], + dtype=tf.int32, + name='step_type'), + reward=tf.convert_to_tensor([reward], dtype=tf.float32, + name='reward'), + discount=tf.convert_to_tensor([0.0], + dtype=tf.float32, + name='discount'), + observation=obs, + ) + return timestep_comp + + test_env = env.MLGOEnvironmentBase( + clang_path=env_test._CLANG_PATH, + task_type=env_test.MockTask, + obs_spec={}, + action_spec={}, + ) + + exploration_worker = generate_bc_trajectories.ExplorationWorker( + loaded_module_spec=env_test._MOCK_MODULE, + clang_path=env_test._CLANG_PATH, + mlgo_task=env_test.MockTask, + reward_key='default', + ) + + curr_step_obs = test_env.reset(env_test._MOCK_MODULE) + timestep = exploration_worker._create_timestep(curr_step_obs) + timestep_comp = create_timestep_comp(0, 0., curr_step_obs.obs) + self.assertEqual(timestep, timestep_comp) + + for step_itr in range(env_test._NUM_STEPS - 1): + del step_itr + curr_step_obs = test_env.step(np.array([1], dtype=np.int64)) + timestep = exploration_worker._create_timestep(curr_step_obs) + timestep_comp = create_timestep_comp(1, 0., curr_step_obs.obs) + self.assertEqual(timestep, timestep_comp) + + curr_step_obs = test_env.step(np.array([1], dtype=np.int64)) + timestep = exploration_worker._create_timestep(curr_step_obs) + timestep_comp = create_timestep_comp(2, 47., curr_step_obs.obs) + self.assertEqual(timestep, timestep_comp) + + @mock.patch('subprocess.Popen') + def test_compile_module(self, mock_popen): + mock_popen.side_effect = env_test.mock_interactive_clang + + seq_example_comp = tf.train.SequenceExample() + for i in range(10): + generate_bc_trajectories.add_int_feature(seq_example_comp, i, + 'times_called') + generate_bc_trajectories.add_string_feature(seq_example_comp, 'module', + 'module_name') + generate_bc_trajectories.add_float_feature(seq_example_comp, 47.0, + 'reward') + generate_bc_trajectories.add_int_feature(seq_example_comp, np.mod(i, 5), + 'action') + + exploration_worker = 
generate_bc_trajectories.ExplorationWorker( + loaded_module_spec=env_test._MOCK_MODULE, + clang_path=env_test._CLANG_PATH, + mlgo_task=env_test.MockTask, + reward_key='default', + ) + + seq_example = exploration_worker.compile_module(_policy) + self.assertEqual(seq_example, seq_example_comp) From a2d39b2f9258aae20e6dc7c6dfede28244b9daf0 Mon Sep 17 00:00:00 2001 From: "Teodor V. Marinov" Date: Mon, 14 Oct 2024 18:53:55 +0000 Subject: [PATCH 2/6] Suppressing pytype error for protobuf import. --- compiler_opt/rl/generate_bc_trajectories_test.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/compiler_opt/rl/generate_bc_trajectories_test.py b/compiler_opt/rl/generate_bc_trajectories_test.py index 5a277b58..aae60f7b 100644 --- a/compiler_opt/rl/generate_bc_trajectories_test.py +++ b/compiler_opt/rl/generate_bc_trajectories_test.py @@ -23,7 +23,7 @@ from tf_agents.trajectories import policy_step from tf_agents.trajectories import time_step -from google.protobuf import text_format +from google.protobuf import text_format # pytype: disable=pyi-error from compiler_opt.rl import generate_bc_trajectories from compiler_opt.rl import env From 59221f104a47af8e316aa34746717cec782779ec Mon Sep 17 00:00:00 2001 From: "Teodor V. Marinov" Date: Wed, 16 Oct 2024 22:37:37 +0000 Subject: [PATCH 3/6] Addressing comments by @mtrofin. --- compiler_opt/rl/generate_bc_trajectories.py | 96 +++++++++++++-------- 1 file changed, 62 insertions(+), 34 deletions(-) diff --git a/compiler_opt/rl/generate_bc_trajectories.py b/compiler_opt/rl/generate_bc_trajectories.py index 977ae15c..90727232 100644 --- a/compiler_opt/rl/generate_bc_trajectories.py +++ b/compiler_opt/rl/generate_bc_trajectories.py @@ -17,6 +17,7 @@ from typing import Any, Callable, Dict, List, Optional, Tuple, Type from absl import logging +import dataclasses import os import shutil @@ -41,7 +42,8 @@ def add_int_feature( Args: sequence_example: sequence example to use instead of the one belonging to the instance - feature_value: tf.int64 value of feature + feature_value: np.int64 value of feature, this is the required type by + tf.train.SequenceExample for an int list feature_name: name of feature """ f = sequence_example.feature_lists.feature_list[feature_name].feature.add() @@ -59,7 +61,8 @@ def add_float_feature( Args: sequence_example: sequence example to use instead of the one belonging to the instance - feature_value: tf.int64 value of feature + feature_value: np.float32 value of feature, this is the required type by + tf.train.SequenceExample for an float list feature_name: name of feature """ f = sequence_example.feature_lists.feature_list[feature_name].feature.add() @@ -111,6 +114,13 @@ def add_feature_list(seq_example: tf.train.SequenceExample, add_function(seq_example, feature, feature_name) +@dataclasses.dataclass +class SequenceExampleFeatureNames: + action: str = 'action' + reward: str = 'reward' + module_name: str = 'module_name' + + class ExplorationWithPolicy: """Policy which selects states for exploration. 
@@ -219,10 +229,11 @@ def __init__( exploration_policy_distr=None, obs_action_specs: Optional[Tuple[time_step.TimeStep, tensor_spec.BoundedTensorSpec,]] = None, + reward_key: str = '', **kwargs, ): - self.loaded_module_spec = loaded_module_spec - self.use_greedy = use_greedy + self._loaded_module_spec = loaded_module_spec + self._use_greedy = use_greedy if not obs_action_specs: obs_spec = None action_spec = None @@ -230,28 +241,28 @@ def __init__( obs_spec = obs_action_specs[0].observation action_spec = obs_action_specs[1] - if 'reward_key' not in kwargs: - raise KeyError( + if reward_key == '': + raise TypeError( 'reward_key not specified in ExplorationWorker initialization.') - self.reward_key = kwargs['reward_key'] + self._reward_key = reward_key kwargs.pop('reward_key', None) - self.working_dir = None + self._working_dir = None - self.env = env.MLGOEnvironmentBase( + self._env = env.MLGOEnvironmentBase( clang_path=clang_path, task_type=mlgo_task, obs_spec=obs_spec, action_spec=action_spec, ) - if self.env.action_spec: - if self.env.action_spec.dtype != tf.int64: + if self._env.action_spec: + if self._env.action_spec.dtype != tf.int64: raise TypeError( f'Environment action_spec type {0} does not match tf.int64'.format( - self.env.action_spec.dtype)) - self.exploration_frac = exploration_frac - self.max_exploration_steps = max_exploration_steps - self.exploration_policy_distr = exploration_policy_distr - logging.info('Reward key in exploration worker: %s', self.reward_key) + self._env.action_spec.dtype)) + self._exploration_frac = exploration_frac + self._max_exploration_steps = max_exploration_steps + self._exploration_policy_distr = exploration_policy_distr + logging.info('Reward key in exploration worker: %s', self._reward_key) def compile_module( self, @@ -264,42 +275,51 @@ def compile_module( Returns: sequence_example: a tf.train.SequenceExample containing the trajectory - from compilation. + from compilation. In addition to the features returned from the env + tbe sequence_example adds the following extra features: action, + reward and module_name. action is the action taken at any given step, + reward is the reward specified by reward_key, not necessarily the + reward returned by the environment and module_name is the name of + the module processed by the compiler. 
""" sequence_example = tf.train.SequenceExample() - curr_obs_dict = self.env.reset(self.loaded_module_spec) + curr_obs_dict = self._env.reset(self._loaded_module_spec) try: curr_obs = curr_obs_dict.obs self._process_obs(curr_obs, sequence_example) while curr_obs_dict.step_type != env.StepType.LAST: timestep = self._create_timestep(curr_obs_dict) action = policy(timestep) - add_int_feature(sequence_example, int(action), 'action') - curr_obs_dict = self.env.step(action) + add_int_feature(sequence_example, int(action), + SequenceExampleFeatureNames.action) + curr_obs_dict = self._env.step(action) curr_obs = curr_obs_dict.obs if curr_obs_dict.step_type == env.StepType.LAST: break self._process_obs(curr_obs, sequence_example) except AssertionError as e: logging.error('AssertionError: %s', e) - horizon = len(sequence_example.feature_lists.feature_list['action'].feature) - self.working_dir = curr_obs_dict.working_dir + horizon = len(sequence_example.feature_lists.feature_list[ + SequenceExampleFeatureNames.action].feature) + self._working_dir = curr_obs_dict.working_dir if horizon <= 0: - working_dir_head = os.path.split(self.working_dir)[0] + working_dir_head = os.path.split(self._working_dir)[0] shutil.rmtree(working_dir_head) if horizon <= 0: raise ValueError( f'Policy did not take any inlining decision for module {0}.'.format( - self.loaded_module_spec.name)) + self._loaded_module_spec.name)) if curr_obs_dict.step_type != env.StepType.LAST: raise ValueError( f'Compilation loop exited at step type {0} before last step'.format( curr_obs_dict.step_type)) - native_size = curr_obs_dict.score_policy[self.reward_key] - native_size_list = np.float32(native_size) * np.float32(np.ones(horizon)) - add_feature_list(sequence_example, native_size_list, 'reward') - module_name_list = [self.loaded_module_spec.name for _ in range(horizon)] - add_feature_list(sequence_example, module_name_list, 'module_name') + reward = curr_obs_dict.score_policy[self._reward_key] + reward_list = np.float32(reward) * np.float32(np.ones(horizon)) + add_feature_list(sequence_example, reward_list, + SequenceExampleFeatureNames.reward) + module_name_list = [self._loaded_module_spec.name for _ in range(horizon)] + add_feature_list(sequence_example, module_name_list, + SequenceExampleFeatureNames.module_name) return sequence_example def _create_timestep(self, curr_obs_dict: env.TimeStep): @@ -311,7 +331,7 @@ def _create_timestep(self, curr_obs_dict: env.TimeStep): env.StepType.LAST: 2, } if curr_obs_dict.step_type == env.StepType.LAST: - reward = np.array(curr_obs_dict.score_policy[self.reward_key]) + reward = np.array(curr_obs_dict.score_policy[self._reward_key]) else: reward = np.array(0.) 
curr_obs_step = step_type_converter[curr_obs_step] @@ -327,13 +347,21 @@ def _create_timestep(self, curr_obs_dict: env.TimeStep): def _process_obs(self, curr_obs, sequence_example): for curr_obs_feature_name in curr_obs: - if not self.env.obs_spec: + if not self._env.obs_spec: obs_dtype = tf.int64 else: - if curr_obs_feature_name not in self.env.obs_spec.keys(): + if curr_obs_feature_name not in self._env.obs_spec.keys(): raise AssertionError(f'Feature name {0} not in obs_spec {1}'.format( - curr_obs_feature_name, self.env.obs_spec.keys())) - obs_dtype = self.env.obs_spec[curr_obs_feature_name].dtype + curr_obs_feature_name, self._env.obs_spec.keys())) + if curr_obs_feature_name in [ + SequenceExampleFeatureNames.action, + SequenceExampleFeatureNames.reward, + SequenceExampleFeatureNames.module_name + ]: + raise AssertionError( + f'Feature name {0} already part of SequenceExampleFeatureNames' + .format(curr_obs_feature_name, self._env.obs_spec.keys())) + obs_dtype = self._env.obs_spec[curr_obs_feature_name].dtype curr_obs_feature = curr_obs[curr_obs_feature_name] curr_obs[curr_obs_feature_name] = tf.convert_to_tensor( curr_obs_feature, dtype=obs_dtype, name=curr_obs_feature_name) From 0474b603d66ed95ebf10c4ed42619974acdf2a1c Mon Sep 17 00:00:00 2001 From: "Teodor V. Marinov" Date: Thu, 17 Oct 2024 00:34:12 +0000 Subject: [PATCH 4/6] Fixing nits. --- compiler_opt/rl/generate_bc_trajectories.py | 1 + 1 file changed, 1 insertion(+) diff --git a/compiler_opt/rl/generate_bc_trajectories.py b/compiler_opt/rl/generate_bc_trajectories.py index 90727232..49f21fe4 100644 --- a/compiler_opt/rl/generate_bc_trajectories.py +++ b/compiler_opt/rl/generate_bc_trajectories.py @@ -116,6 +116,7 @@ def add_feature_list(seq_example: tf.train.SequenceExample, @dataclasses.dataclass class SequenceExampleFeatureNames: + """Feature names for features that are always added to seq example.""" action: str = 'action' reward: str = 'reward' module_name: str = 'module_name' From 6aa7c828a2769b85f063402f4caf3e74d6a9875d Mon Sep 17 00:00:00 2001 From: "Teodor V. Marinov" Date: Fri, 18 Oct 2024 15:26:58 +0000 Subject: [PATCH 5/6] exploration_worker.explore_function commit. Explores the module using the given policy and the exploration distr. This implementation assumes the action set is only {0,1} and will just switch the action played at explore_step. --- compiler_opt/rl/generate_bc_trajectories.py | 221 +++++++++++++++--- .../rl/generate_bc_trajectories_test.py | 142 +++++++++-- 2 files changed, 313 insertions(+), 50 deletions(-) diff --git a/compiler_opt/rl/generate_bc_trajectories.py b/compiler_opt/rl/generate_bc_trajectories.py index 49f21fe4..7a889327 100644 --- a/compiler_opt/rl/generate_bc_trajectories.py +++ b/compiler_opt/rl/generate_bc_trajectories.py @@ -14,13 +14,14 @@ # limitations under the License. 
"""Module for running compilation and collect data for behavior cloning.""" -from typing import Any, Callable, Dict, List, Optional, Tuple, Type +from typing import Any, Callable, Dict, List, Optional, Tuple, Type, Generator from absl import logging import dataclasses import os import shutil +import math import numpy as np import tensorflow as tf from tf_agents.trajectories import policy_step @@ -32,6 +33,30 @@ from compiler_opt.rl import env +@dataclasses.dataclass +class SequenceExampleFeatureNames: + """Feature names for features that are always added to seq example.""" + action: str = 'action' + reward: str = 'reward' + module_name: str = 'module_name' + + +def get_loss(seq_example: tf.train.SequenceExample, + reward_key: str = SequenceExampleFeatureNames.reward) -> int: + """Return the last loss/reward of a trajectory written in a SequenceExample. + + Args: + seq_example: tf.train.SequenceExample which contains the trajectory with + all features, including a reward feature + reward_key: the name of the feature that contains the loss/reward. + + Returns: + The loss/reward of a trajectory written in a SequenceExample. + """ + return (seq_example.feature_lists.feature_list[reward_key].feature[-1] + .float_list.value[0]) + + def add_int_feature( sequence_example: tf.train.SequenceExample, feature_value: np.int64, @@ -114,14 +139,6 @@ def add_feature_list(seq_example: tf.train.SequenceExample, add_function(seq_example, feature, feature_name) -@dataclasses.dataclass -class SequenceExampleFeatureNames: - """Feature names for features that are always added to seq example.""" - action: str = 'action' - reward: str = 'reward' - module_name: str = 'module_name' - - class ExplorationWithPolicy: """Policy which selects states for exploration. @@ -140,6 +157,7 @@ class ExplorationWithPolicy: explore_policy: randomized policy which is used to compute the gap curr_step: current step of the trajectory explore_step: current candidate for exploration step + explore_state: current candidate state for exploration at explore_step gap: current difference at explore step between probability of most likely action according to explore_policy and second most likely action explore_on_features: dict of feature names and functions which specify @@ -154,13 +172,15 @@ def __init__( explore_on_features: Optional[Dict[str, Callable[[tf.Tensor], bool]]] = None, ): - self.replay_prefix = replay_prefix - self.policy = policy - self.explore_policy = explore_policy - self.curr_step = 0 - self.explore_step = 0 - self.gap = np.inf - self.explore_on_features = explore_on_features + self.explore_step: int = len(replay_prefix) - 1 + self.explore_state: time_step.TimeStep | None = None + self._replay_prefix = replay_prefix + self._policy = policy + self._explore_policy = explore_policy + self._curr_step = 0 + self._gap = np.inf + self._explore_on_features: Optional[Dict[str, Callable[ + [tf.Tensor], bool]]] = explore_on_features self._stop_exploration = False def _compute_gap(self, distr: np.ndarray) -> np.float32: @@ -179,28 +199,29 @@ def get_advice(self, state: time_step.TimeStep) -> np.ndarray: policy_action: action to take at the current state. 
""" - if self.curr_step < len(self.replay_prefix): - self.curr_step += 1 - return np.array(self.replay_prefix[self.curr_step - 1]) - policy_action = self.policy(state) + if self._curr_step < len(self._replay_prefix): + self._curr_step += 1 + return np.array(self._replay_prefix[self._curr_step - 1]) + policy_action = self._policy(state) # explore_policy(state) should play at least one action per state and so # self.explore_policy(state).action.logits should have at least one entry - distr = tf.nn.softmax(self.explore_policy(state).action.logits).numpy()[0] + distr = tf.nn.softmax(self._explore_policy(state).action.logits).numpy()[0] curr_gap = self._compute_gap(distr) # selecting explore_step is done based on smallest encountered gap in the # play of self.policy. This logic can be changed to have different type # of exploration. if (not self._stop_exploration and distr.shape[0] > 1 and - self.gap > curr_gap): - self.gap = curr_gap - self.explore_step = self.curr_step - if not self._stop_exploration and self.explore_on_features is not None: - for feature_name, explore_on_feature in self.explore_on_features.items(): + self._gap > curr_gap): + self._gap = curr_gap + self.explore_step = self._curr_step + self.explore_state = state + if not self._stop_exploration and self._explore_on_features is not None: + for feature_name, explore_on_feature in self._explore_on_features.items(): if explore_on_feature(state.observation[feature_name]): - self.explore_step = self.curr_step + self.explore_step = self._curr_step self._stop_exploration = True break - self.curr_step += 1 + self._curr_step += 1 return policy_action @@ -214,7 +235,10 @@ class ExplorationWorker(worker.Worker): env: MLGO environment. exploration_frac: how often to explore in a trajectory max_exploration_steps: maximum number of exploration steps - exploration_policy_distr: distribution function from exploration policy. + max_horizon_to_explore: if the horizon under policy is greater than this + we do not do exploration + explore_on_features: dict of feature names and functions which specify + when to explore on the respective feature reward_key: which reward binary to use, must be specified as part of additional task args (kwargs). 
""" @@ -224,17 +248,17 @@ def __init__( loaded_module_spec: corpus.LoadedModuleSpec, clang_path: str, mlgo_task: Type[env.MLGOTask], - use_greedy: bool = False, exploration_frac: float = 1.0, max_exploration_steps: int = 10, - exploration_policy_distr=None, + max_horizon_to_explore=np.inf, + explore_on_features: Optional[Dict[str, Callable[[tf.Tensor], + bool]]] = None, obs_action_specs: Optional[Tuple[time_step.TimeStep, tensor_spec.BoundedTensorSpec,]] = None, reward_key: str = '', **kwargs, ): self._loaded_module_spec = loaded_module_spec - self._use_greedy = use_greedy if not obs_action_specs: obs_spec = None action_spec = None @@ -262,7 +286,8 @@ def __init__( self._env.action_spec.dtype)) self._exploration_frac = exploration_frac self._max_exploration_steps = max_exploration_steps - self._exploration_policy_distr = exploration_policy_distr + self._max_horizon_to_explore = max_horizon_to_explore + self._explore_on_features = explore_on_features logging.info('Reward key in exploration worker: %s', self._reward_key) def compile_module( @@ -323,6 +348,136 @@ def compile_module( SequenceExampleFeatureNames.module_name) return sequence_example + def explore_function( + self, + policy: Callable[[Optional[time_step.TimeStep]], np.ndarray], + explore_policy: Optional[Callable[[time_step.TimeStep], + policy_step.PolicyStep]] = None, + ) -> Tuple[List[tf.train.SequenceExample], List[str], int, float]: + """Explores the module using the given policy and the exploration distr. + + Args: + policy: policy which acts on all states outside of the exploration states. + explore_policy: randomized policy which is used to compute the gap for + exploration and can be used for deciding which actions to explore at + the exploration state. + + Returns: + seq_example_list: a tf.train.SequenceExample list containing the all + trajectories from exploration. + working_dir_names: the directories of the compiled binaries + loss_idx: idx of the smallest loss trajectory in the seq_example_list. + base_seq_loss: loss of the trajectory compiled with policy. 
+ """ + seq_example_list = [] + working_dir_names = [] + loss_idx = 0 + exploration_steps = 0 + + if not explore_policy: + base_seq = self.compile_module(policy) + seq_example_list.append(base_seq) + working_dir_names.append(self._working_dir) + return ( + seq_example_list, + working_dir_names, + loss_idx, + get_loss(base_seq), + ) + + base_policy = ExplorationWithPolicy( + [], + policy, + explore_policy, + self._explore_on_features, + ) + base_seq = self.compile_module(base_policy.get_advice) + seq_example_list.append(base_seq) + working_dir_names.append(self._working_dir) + base_seq_loss = get_loss(base_seq) + horizon = len(base_seq.feature_lists.feature_list[ + SequenceExampleFeatureNames.action].feature) + num_states = int(math.ceil(self._exploration_frac * horizon)) + num_states = min(num_states, self._max_exploration_steps) + if num_states < 1 or horizon > self._max_horizon_to_explore: + return seq_example_list, working_dir_names, loss_idx, base_seq_loss + + seq_losses = [base_seq_loss] + for num_steps in range(num_states): + explore_step = base_policy.explore_step + if explore_step >= horizon: + break + replay_prefix = base_seq.feature_lists.feature_list[ + SequenceExampleFeatureNames.action].feature + replay_prefix = self._build_replay_prefix_list( + replay_prefix[:explore_step + 1]) + explore_state = base_policy.explore_state + for base_seq, base_policy in self.explore_at_state_generator( + replay_prefix, explore_step, explore_state, policy, explore_policy): + exploration_steps += 1 + seq_example_list.append(base_seq) + working_dir_names.append(self._working_dir) + seq_loss = get_loss(base_seq) + seq_losses.append(seq_loss) + # <= biases towards more exploration in the dataset, < towards less expl + if seq_loss < base_seq_loss: + base_seq_loss = seq_loss + loss_idx = num_steps + 1 + logging.info('module exploration losses: %s', seq_losses) + if exploration_steps > self._max_exploration_steps: + return seq_example_list, working_dir_names, loss_idx, base_seq_loss + horizon = len(base_seq.feature_lists.feature_list[ + SequenceExampleFeatureNames.action].feature) + # check if we are at the end of the trajectory and the last was explored + if (explore_step == base_policy.explore_step and + explore_step == horizon - 1): + return seq_example_list, working_dir_names, loss_idx, base_seq_loss + + return seq_example_list, working_dir_names, loss_idx, base_seq_loss + + def explore_at_state_generator( + self, replay_prefix: List[np.ndarray], explore_step: int, + explore_state: time_step.TimeStep, + policy: Callable[[Optional[time_step.TimeStep]], np.ndarray], + explore_policy: Callable[[time_step.TimeStep], policy_step.PolicyStep] + ) -> Generator[Tuple[tf.train.SequenceExample, ExplorationWithPolicy], None, + None]: + """Generate sequence examples and next exploration policy while exploring. + + Generator that defines how to explore at the given explore_step. This + implementation assumes the action set is only {0,1} and will just switch + the action played at explore_step. + + Args: + replay_prefix: a replay buffer of actions + explore_step: exploration step in the previous compiled trajectory + explore_state: state for exploration at explore_step + policy: policy which acts on all states outside of the exploration states. + explore_policy: randomized policy which is used to compute the gap for + exploration and can be used for deciding which actions to explore at + the exploration state. 
+ + Yields: + base_seq: a tf.train.SequenceExample containing a compiled trajectory + base_policy: the policy used to determine the next exploration step + """ + del explore_state + replay_prefix[explore_step] = 1 - replay_prefix[explore_step] + base_policy = ExplorationWithPolicy( + replay_prefix, + policy, + explore_policy, + self._explore_on_features, + ) + base_seq = self.compile_module(base_policy.get_advice) + yield base_seq, base_policy + + def _build_replay_prefix_list(self, seq_ex): + ret_list = [] + for int_list in seq_ex: + ret_list.append(int_list.int64_list.value[0]) + return ret_list + def _create_timestep(self, curr_obs_dict: env.TimeStep): curr_obs = curr_obs_dict.obs curr_obs_step = curr_obs_dict.step_type diff --git a/compiler_opt/rl/generate_bc_trajectories_test.py b/compiler_opt/rl/generate_bc_trajectories_test.py index aae60f7b..4c2acd5f 100644 --- a/compiler_opt/rl/generate_bc_trajectories_test.py +++ b/compiler_opt/rl/generate_bc_trajectories_test.py @@ -81,35 +81,37 @@ def _policy(state: time_step.TimeStep) -> np.ndarray: return np.mod(feature_sum, 5) -def _explore_policy(state: time_step.TimeStep) -> policy_step.PolicyStep: - probs = [ - 0.5 * float(state.observation['feature_3'].numpy()), - 1 - 0.5 * float(state.observation['feature_3'].numpy()) - ] - logits = [[0.0, tf.math.log(probs[1] / (1.0 - probs[1] + _eps))]] - return policy_step.PolicyStep( - action=tfp.distributions.Categorical(logits=logits)) - - class ExplorationWithPolicyTest(tf.test.TestCase): + def _explore_policy(self, + state: time_step.TimeStep) -> policy_step.PolicyStep: + probs = [ + 0.5 * float(state.observation['feature_3'].numpy()), + 1 - 0.5 * float(state.observation['feature_3'].numpy()) + ] + logits = [[0.0, tf.math.log(probs[1] / (1.0 - probs[1] + _eps))]] + return policy_step.PolicyStep( + action=tfp.distributions.Categorical(logits=logits)) + def test_explore_policy(self): prob = 1. 
state = _get_state_list()[3] logits = [[0.0, tf.math.log(prob / (1.0 - prob + _eps))]] action = tfp.distributions.Categorical(logits=logits) - self.assertAllClose(action.logits, _explore_policy(state).action.logits) + self.assertAllClose(action.logits, + self._explore_policy(state).action.logits) def test_explore_with_gap(self): + # pylint: disable=protected-access explore_with_policy = generate_bc_trajectories.ExplorationWithPolicy( replay_prefix=[np.array([1])], policy=_policy, - explore_policy=_explore_policy, + explore_policy=self._explore_policy, ) for state in _get_state_list(): _ = explore_with_policy.get_advice(state)[0] - self.assertAllClose(0, explore_with_policy.gap, atol=2 * _eps) + self.assertAllClose(0, explore_with_policy._gap, atol=2 * _eps) self.assertEqual(2, explore_with_policy.explore_step) explore_with_policy = generate_bc_trajectories.ExplorationWithPolicy( @@ -117,12 +119,12 @@ def test_explore_with_gap(self): np.array([1]), np.array([1])], policy=_policy, - explore_policy=_explore_policy, + explore_policy=self._explore_policy, ) for state in _get_state_list(): _ = explore_with_policy.get_advice(state)[0] - self.assertAllClose(1, explore_with_policy.gap, atol=2 * _eps) + self.assertAllClose(1, explore_with_policy._gap, atol=2 * _eps) self.assertEqual(3, explore_with_policy.explore_step) def test_explore_with_feature(self): @@ -141,7 +143,7 @@ def explore_on_feature_2_val(feature_val): explore_with_policy = generate_bc_trajectories.ExplorationWithPolicy( replay_prefix=[], policy=_policy, - explore_policy=_explore_policy, + explore_policy=self._explore_policy, explore_on_features=explore_on_features) for state in _get_state_list(): _ = explore_with_policy.get_advice(state)[0] @@ -150,7 +152,7 @@ def explore_on_feature_2_val(feature_val): explore_with_policy = generate_bc_trajectories.ExplorationWithPolicy( replay_prefix=[np.array([1])], policy=_policy, - explore_policy=_explore_policy, + explore_policy=self._explore_policy, explore_on_features=explore_on_features, ) @@ -345,3 +347,109 @@ def test_compile_module(self, mock_popen): seq_example = exploration_worker.compile_module(_policy) self.assertEqual(seq_example, seq_example_comp) + + def _get_seq_example_list_comp(self): + seq_example_list_comp = [] + + # no exploration + seq_example_comp = tf.train.SequenceExample() + for i in range(10): + generate_bc_trajectories.add_int_feature(seq_example_comp, i, + 'times_called') + generate_bc_trajectories.add_string_feature(seq_example_comp, 'module', + 'module_name') + generate_bc_trajectories.add_float_feature(seq_example_comp, 47.0, + 'reward') + generate_bc_trajectories.add_int_feature(seq_example_comp, np.mod(i, 5), + 'action') + seq_example_list_comp.append(seq_example_comp) + + # first exploration trajectory, tests explore with gap + seq_example_comp = tf.train.SequenceExample() + for i in range(10): + generate_bc_trajectories.add_int_feature(seq_example_comp, i, + 'times_called') + generate_bc_trajectories.add_string_feature(seq_example_comp, 'module', + 'module_name') + generate_bc_trajectories.add_float_feature(seq_example_comp, 47.0, + 'reward') + if i == 4: + generate_bc_trajectories.add_int_feature(seq_example_comp, -3, 'action') + else: + generate_bc_trajectories.add_int_feature(seq_example_comp, np.mod(i, 5), + 'action') + seq_example_list_comp.append(seq_example_comp) + + # second exploration trajectory, tests explore on feature + seq_example_comp = tf.train.SequenceExample() + for i in range(10): + generate_bc_trajectories.add_int_feature(seq_example_comp, i, 
+ 'times_called') + generate_bc_trajectories.add_string_feature(seq_example_comp, 'module', + 'module_name') + generate_bc_trajectories.add_float_feature(seq_example_comp, 47.0, + 'reward') + if i == 4: + generate_bc_trajectories.add_int_feature(seq_example_comp, -3, 'action') + elif i == 5: + generate_bc_trajectories.add_int_feature(seq_example_comp, 1, 'action') + else: + generate_bc_trajectories.add_int_feature(seq_example_comp, np.mod(i, 5), + 'action') + seq_example_list_comp.append(seq_example_comp) + + # third exploration trajectory, tests explore with gap + seq_example_comp = tf.train.SequenceExample() + for i in range(10): + generate_bc_trajectories.add_int_feature(seq_example_comp, i, + 'times_called') + generate_bc_trajectories.add_string_feature(seq_example_comp, 'module', + 'module_name') + generate_bc_trajectories.add_float_feature(seq_example_comp, 47.0, + 'reward') + if i == 4: + generate_bc_trajectories.add_int_feature(seq_example_comp, -3, 'action') + elif i == 5: + generate_bc_trajectories.add_int_feature(seq_example_comp, 1, 'action') + elif i == 9: + generate_bc_trajectories.add_int_feature(seq_example_comp, -3, 'action') + else: + generate_bc_trajectories.add_int_feature(seq_example_comp, np.mod(i, 5), + 'action') + seq_example_list_comp.append(seq_example_comp) + + return seq_example_list_comp + + @mock.patch('subprocess.Popen') + def test_explore_function(self, mock_popen): + mock_popen.side_effect = env_test.mock_interactive_clang + + def _explore_on_feature_func(feature_val) -> bool: + return feature_val[0] in [4, 5] + + exploration_worker = generate_bc_trajectories.ExplorationWorker( + loaded_module_spec=env_test._MOCK_MODULE, + clang_path=env_test._CLANG_PATH, + mlgo_task=env_test.MockTask, + reward_key='default', + explore_on_features={'times_called': _explore_on_feature_func}) + + def _explore_policy(state: time_step.TimeStep): + times_called = state.observation['times_called'][0] + # will explore every 4-th step + logits = [[ + 4.0 + 1e-3 * float(env_test._NUM_STEPS - times_called), + float(np.mod(times_called, 5)) + ]] + return policy_step.PolicyStep( + action=tfp.distributions.Categorical(logits=logits)) + + (seq_example_list, working_dir_names, loss_idx, + base_seq_loss) = exploration_worker.explore_function( + _policy, _explore_policy) + del working_dir_names + + self.assertEqual(loss_idx, 0) + self.assertEqual(base_seq_loss, 47.0) + seq_example_list_comp = self._get_seq_example_list_comp() + self.assertListEqual(seq_example_list, seq_example_list_comp) From d9c03653d378d4cf46c175fb7a3ebf9a587a15c5 Mon Sep 17 00:00:00 2001 From: "Teodor V. Marinov" Date: Fri, 18 Oct 2024 15:33:58 +0000 Subject: [PATCH 6/6] Fix pylint. --- compiler_opt/rl/generate_bc_trajectories.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/compiler_opt/rl/generate_bc_trajectories.py b/compiler_opt/rl/generate_bc_trajectories.py index 7a889327..f8085a4a 100644 --- a/compiler_opt/rl/generate_bc_trajectories.py +++ b/compiler_opt/rl/generate_bc_trajectories.py @@ -173,7 +173,7 @@ def __init__( bool]]] = None, ): self.explore_step: int = len(replay_prefix) - 1 - self.explore_state: time_step.TimeStep | None = None + self.explore_state: Optional[time_step.TimeStep] = None self._replay_prefix = replay_prefix self._policy = policy self._explore_policy = explore_policy
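
For reference, below is a minimal sketch of how the pieces added in this series fit together once all patches are applied: a deterministic behavior policy, a randomized explore_policy whose action distribution drives the choice of exploration step, and an ExplorationWorker whose explore_function collects the explored trajectories. It is illustrative only and not part of the patches: loaded_module_spec, CLANG_PATH and MyTask are placeholders for a real corpus.LoadedModuleSpec, an interactive-clang binary and an env.MLGOTask subclass (the unit tests exercise the same flow with a mocked clang), and the reward key is assumed to match one the task reports.

import numpy as np
import tensorflow_probability as tfp
from tf_agents.trajectories import policy_step
from tf_agents.trajectories import time_step

from compiler_opt.rl import generate_bc_trajectories


def policy(state: time_step.TimeStep) -> np.ndarray:
  # Deterministic behavior policy over the binary action set {0, 1}:
  # always decline (action 0).
  del state
  return np.array(0, dtype=np.int64)


def explore_policy(state: time_step.TimeStep) -> policy_step.PolicyStep:
  # Randomized policy used only to pick exploration steps: the smaller the
  # gap between its action probabilities at a state, the more likely that
  # state becomes the explore_step.
  del state
  return policy_step.PolicyStep(
      action=tfp.distributions.Categorical(logits=[[0.0, 0.0]]))


# loaded_module_spec, CLANG_PATH and MyTask are placeholders, not real names.
worker = generate_bc_trajectories.ExplorationWorker(
    loaded_module_spec=loaded_module_spec,  # a corpus.LoadedModuleSpec
    clang_path=CLANG_PATH,  # path to an interactive-clang binary
    mlgo_task=MyTask,  # an env.MLGOTask subclass
    exploration_frac=0.5,
    max_exploration_steps=10,
    reward_key='default',  # assumed to be one of the task's score keys
)

seq_examples, work_dirs, loss_idx, base_loss = worker.explore_function(
    policy, explore_policy)
best = seq_examples[loss_idx]
print(base_loss, generate_bc_trajectories.get_loss(best))

In this sketch, explore_function first compiles the module with policy alone (seq_examples[0]); each subsequent trajectory flips the action at the step where explore_policy's action distribution was closest to uniform, or where an explore_on_features predicate fired, and loss_idx points at the lowest-loss trajectory collected.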