From 25b4e6ba47bbc8a6680c0581d1bbd8ae0fbbf1e3 Mon Sep 17 00:00:00 2001 From: "Teodor V. Marinov" Date: Mon, 14 Oct 2024 18:42:15 +0000 Subject: [PATCH 1/4] Initial ExplorationWorker commit. ExplorationWorker implements the interactive clang compilation together with exploration for a given module. This commit adds the class with its compile_module function which will compile the module with a given compiler policy. --- compiler_opt/rl/generate_bc_trajectories.py | 239 +++++++++++++++++- .../rl/generate_bc_trajectories_test.py | 193 ++++++++++++++ 2 files changed, 431 insertions(+), 1 deletion(-) diff --git a/compiler_opt/rl/generate_bc_trajectories.py b/compiler_opt/rl/generate_bc_trajectories.py index 79d65f2d..977ae15c 100644 --- a/compiler_opt/rl/generate_bc_trajectories.py +++ b/compiler_opt/rl/generate_bc_trajectories.py @@ -14,12 +14,101 @@ # limitations under the License. """Module for running compilation and collect data for behavior cloning.""" -from typing import Callable, Dict, List, Optional +from typing import Any, Callable, Dict, List, Optional, Tuple, Type + +from absl import logging +import os +import shutil import numpy as np import tensorflow as tf from tf_agents.trajectories import policy_step from tf_agents.trajectories import time_step +from tf_agents.specs import tensor_spec + +from compiler_opt.distributed import worker +from compiler_opt.rl import corpus +from compiler_opt.rl import env + + +def add_int_feature( + sequence_example: tf.train.SequenceExample, + feature_value: np.int64, + feature_name: str, +): + """Add an int feature to feature list. + + Args: + sequence_example: sequence example to use instead of the one belonging to + the instance + feature_value: tf.int64 value of feature + feature_name: name of feature + """ + f = sequence_example.feature_lists.feature_list[feature_name].feature.add() + lst = f.int64_list.value + lst.extend([feature_value]) + + +def add_float_feature( + sequence_example: tf.train.SequenceExample, + feature_value: np.float32, + feature_name: str, +): + """Add a float feature to feature list. + + Args: + sequence_example: sequence example to use instead of the one belonging to + the instance + feature_value: tf.int64 value of feature + feature_name: name of feature + """ + f = sequence_example.feature_lists.feature_list[feature_name].feature.add() + lst = f.float_list.value + lst.extend([feature_value]) + + +def add_string_feature( + sequence_example: tf.train.SequenceExample, + feature_value: str, + feature_name: str, +): + """Add a string feature to feature list. + + Args: + sequence_example: sequence example to use instead of the one + feature_value: tf.string value of feature + feature_name: name of feature + """ + f = sequence_example.feature_lists.feature_list[feature_name].feature.add() + lst = f.bytes_list.value + lst.extend([feature_value.encode('utf-8')]) + + +def add_feature_list(seq_example: tf.train.SequenceExample, + feature_list: List[Any], feature_name: str): + """Add the feature_list to the sequence example under feature name. + + Args: + seq_example: sequence example to update + feature_list: list of feature values to add to seq_example + feature_name: name of the feature to add the list under + """ + if (type(feature_list[0]) not in [ + np.dtype(np.int64), + np.dtype(np.float32), + str, + ]): + raise AssertionError(f'''Unsupported type for feautre {0} of type {1}. 
+ Supported types are np.int64, np.float32, str'''.format( + feature_name, type(feature_list[0]))) + if isinstance(feature_list[0], np.float32): + add_function = add_float_feature + elif isinstance(feature_list[0], (int, np.int64)): + add_function = add_int_feature + else: + add_function = add_string_feature + for feature in feature_list: + add_function(seq_example, feature, feature_name) class ExplorationWithPolicy: @@ -102,3 +191,151 @@ def get_advice(self, state: time_step.TimeStep) -> np.ndarray: break self.curr_step += 1 return policy_action + + +class ExplorationWorker(worker.Worker): + """Class which implements the exploration for the given module. + + Attributes: + loaded_module_spec: the module to be compiled and explored + use_greedy: indicates if the default/greedy policy is used to compile the + module + env: MLGO environment. + exploration_frac: how often to explore in a trajectory + max_exploration_steps: maximum number of exploration steps + exploration_policy_distr: distribution function from exploration policy. + reward_key: which reward binary to use, must be specified as part of + additional task args (kwargs). + """ + + def __init__( + self, + loaded_module_spec: corpus.LoadedModuleSpec, + clang_path: str, + mlgo_task: Type[env.MLGOTask], + use_greedy: bool = False, + exploration_frac: float = 1.0, + max_exploration_steps: int = 10, + exploration_policy_distr=None, + obs_action_specs: Optional[Tuple[time_step.TimeStep, + tensor_spec.BoundedTensorSpec,]] = None, + **kwargs, + ): + self.loaded_module_spec = loaded_module_spec + self.use_greedy = use_greedy + if not obs_action_specs: + obs_spec = None + action_spec = None + else: + obs_spec = obs_action_specs[0].observation + action_spec = obs_action_specs[1] + + if 'reward_key' not in kwargs: + raise KeyError( + 'reward_key not specified in ExplorationWorker initialization.') + self.reward_key = kwargs['reward_key'] + kwargs.pop('reward_key', None) + self.working_dir = None + + self.env = env.MLGOEnvironmentBase( + clang_path=clang_path, + task_type=mlgo_task, + obs_spec=obs_spec, + action_spec=action_spec, + ) + if self.env.action_spec: + if self.env.action_spec.dtype != tf.int64: + raise TypeError( + f'Environment action_spec type {0} does not match tf.int64'.format( + self.env.action_spec.dtype)) + self.exploration_frac = exploration_frac + self.max_exploration_steps = max_exploration_steps + self.exploration_policy_distr = exploration_policy_distr + logging.info('Reward key in exploration worker: %s', self.reward_key) + + def compile_module( + self, + policy: Callable[[Optional[time_step.TimeStep]], np.ndarray], + ) -> tf.train.SequenceExample: + """Compiles the module with the given policy and outputs a seq. example. + + Args: + policy: policy to compile with + + Returns: + sequence_example: a tf.train.SequenceExample containing the trajectory + from compilation. 
+ """ + sequence_example = tf.train.SequenceExample() + curr_obs_dict = self.env.reset(self.loaded_module_spec) + try: + curr_obs = curr_obs_dict.obs + self._process_obs(curr_obs, sequence_example) + while curr_obs_dict.step_type != env.StepType.LAST: + timestep = self._create_timestep(curr_obs_dict) + action = policy(timestep) + add_int_feature(sequence_example, int(action), 'action') + curr_obs_dict = self.env.step(action) + curr_obs = curr_obs_dict.obs + if curr_obs_dict.step_type == env.StepType.LAST: + break + self._process_obs(curr_obs, sequence_example) + except AssertionError as e: + logging.error('AssertionError: %s', e) + horizon = len(sequence_example.feature_lists.feature_list['action'].feature) + self.working_dir = curr_obs_dict.working_dir + if horizon <= 0: + working_dir_head = os.path.split(self.working_dir)[0] + shutil.rmtree(working_dir_head) + if horizon <= 0: + raise ValueError( + f'Policy did not take any inlining decision for module {0}.'.format( + self.loaded_module_spec.name)) + if curr_obs_dict.step_type != env.StepType.LAST: + raise ValueError( + f'Compilation loop exited at step type {0} before last step'.format( + curr_obs_dict.step_type)) + native_size = curr_obs_dict.score_policy[self.reward_key] + native_size_list = np.float32(native_size) * np.float32(np.ones(horizon)) + add_feature_list(sequence_example, native_size_list, 'reward') + module_name_list = [self.loaded_module_spec.name for _ in range(horizon)] + add_feature_list(sequence_example, module_name_list, 'module_name') + return sequence_example + + def _create_timestep(self, curr_obs_dict: env.TimeStep): + curr_obs = curr_obs_dict.obs + curr_obs_step = curr_obs_dict.step_type + step_type_converter = { + env.StepType.FIRST: 0, + env.StepType.MID: 1, + env.StepType.LAST: 2, + } + if curr_obs_dict.step_type == env.StepType.LAST: + reward = np.array(curr_obs_dict.score_policy[self.reward_key]) + else: + reward = np.array(0.) 
+ curr_obs_step = step_type_converter[curr_obs_step] + timestep = time_step.TimeStep( + step_type=tf.convert_to_tensor([curr_obs_step], + dtype=tf.int32, + name='step_type'), + reward=tf.convert_to_tensor([reward], dtype=tf.float32, name='reward'), + discount=tf.convert_to_tensor([0.0], dtype=tf.float32, name='discount'), + observation=curr_obs, + ) + return timestep + + def _process_obs(self, curr_obs, sequence_example): + for curr_obs_feature_name in curr_obs: + if not self.env.obs_spec: + obs_dtype = tf.int64 + else: + if curr_obs_feature_name not in self.env.obs_spec.keys(): + raise AssertionError(f'Feature name {0} not in obs_spec {1}'.format( + curr_obs_feature_name, self.env.obs_spec.keys())) + obs_dtype = self.env.obs_spec[curr_obs_feature_name].dtype + curr_obs_feature = curr_obs[curr_obs_feature_name] + curr_obs[curr_obs_feature_name] = tf.convert_to_tensor( + curr_obs_feature, dtype=obs_dtype, name=curr_obs_feature_name) + add_feature_list(sequence_example, curr_obs_feature, + curr_obs_feature_name) diff --git a/compiler_opt/rl/generate_bc_trajectories_test.py b/compiler_opt/rl/generate_bc_trajectories_test.py index a4ae9b6c..5a277b58 100644 --- a/compiler_opt/rl/generate_bc_trajectories_test.py +++ b/compiler_opt/rl/generate_bc_trajectories_test.py @@ -15,6 +15,7 @@ """Tests for compiler_opt.rl.generate_bc_trajectories.""" from typing import List +from unittest import mock import numpy as np import tensorflow as tf @@ -22,7 +23,11 @@ from tf_agents.trajectories import policy_step from tf_agents.trajectories import time_step +from google.protobuf import text_format + from compiler_opt.rl import generate_bc_trajectories +from compiler_opt.rl import env +from compiler_opt.rl import env_test _eps = 1e-5 @@ -152,3 +157,191 @@ def explore_on_feature_2_val(feature_val): for state in _get_state_list(): _ = explore_with_policy.get_advice(state)[0] self.assertEqual(1, explore_with_policy.explore_step) + + +class AddToFeatureListsTest(tf.test.TestCase): + + def test_add_int_feature(self): + sequence_example_text = """ + feature_lists { + feature_list { + key: "feature_0" + value { + feature { int64_list { value: 1 } } + feature { int64_list { value: 2 } } + } + } + feature_list { + key: "feature_1" + value { + feature { int64_list { value: 3 } } + } + } + }""" + sequence_example_comp = text_format.Parse(sequence_example_text, + tf.train.SequenceExample()) + + sequence_example = tf.train.SequenceExample() + generate_bc_trajectories.add_int_feature( + sequence_example=sequence_example, + feature_value=1, + feature_name='feature_0') + generate_bc_trajectories.add_int_feature( + sequence_example=sequence_example, + feature_value=2, + feature_name='feature_0') + generate_bc_trajectories.add_int_feature( + sequence_example=sequence_example, + feature_value=3, + feature_name='feature_1') + + self.assertEqual(sequence_example, sequence_example_comp) + + def test_add_float_feature(self): + sequence_example_text = """ + feature_lists { + feature_list { + key: "feature_0" + value { + feature { float_list { value: .1 } } + feature { float_list { value: .2 } } + } + } + feature_list { + key: "feature_1" + value { + feature { float_list { value: .3 } } + } + } + }""" + sequence_example_comp = text_format.Parse(sequence_example_text, + tf.train.SequenceExample()) + + sequence_example = tf.train.SequenceExample() + generate_bc_trajectories.add_float_feature( + sequence_example=sequence_example, + feature_value=.1, + feature_name='feature_0') + generate_bc_trajectories.add_float_feature( + 
sequence_example=sequence_example, + feature_value=.2, + feature_name='feature_0') + generate_bc_trajectories.add_float_feature( + sequence_example=sequence_example, + feature_value=.3, + feature_name='feature_1') + + self.assertEqual(sequence_example, sequence_example_comp) + + def test_add_string_feature(self): + sequence_example_text = """ + feature_lists { + feature_list { + key: "feature_0" + value { + feature { bytes_list { value: "1" } } + feature { bytes_list { value: "2" } } + } + } + feature_list { + key: "feature_1" + value { + feature { bytes_list { value: "3" } } + } + } + }""" + sequence_example_comp = text_format.Parse(sequence_example_text, + tf.train.SequenceExample()) + + sequence_example = tf.train.SequenceExample() + generate_bc_trajectories.add_string_feature( + sequence_example=sequence_example, + feature_value='1', + feature_name='feature_0') + generate_bc_trajectories.add_string_feature( + sequence_example=sequence_example, + feature_value='2', + feature_name='feature_0') + generate_bc_trajectories.add_string_feature( + sequence_example=sequence_example, + feature_value='3', + feature_name='feature_1') + + self.assertEqual(sequence_example, sequence_example_comp) + + +class ExplorationWorkerTest(tf.test.TestCase): + # pylint: disable=protected-access + @mock.patch('subprocess.Popen') + def test_create_timestep(self, mock_popen): + mock_popen.side_effect = env_test.mock_interactive_clang + + def create_timestep_comp(step_type, reward, obs): + timestep_comp = time_step.TimeStep( + step_type=tf.convert_to_tensor([step_type], + dtype=tf.int32, + name='step_type'), + reward=tf.convert_to_tensor([reward], dtype=tf.float32, + name='reward'), + discount=tf.convert_to_tensor([0.0], + dtype=tf.float32, + name='discount'), + observation=obs, + ) + return timestep_comp + + test_env = env.MLGOEnvironmentBase( + clang_path=env_test._CLANG_PATH, + task_type=env_test.MockTask, + obs_spec={}, + action_spec={}, + ) + + exploration_worker = generate_bc_trajectories.ExplorationWorker( + loaded_module_spec=env_test._MOCK_MODULE, + clang_path=env_test._CLANG_PATH, + mlgo_task=env_test.MockTask, + reward_key='default', + ) + + curr_step_obs = test_env.reset(env_test._MOCK_MODULE) + timestep = exploration_worker._create_timestep(curr_step_obs) + timestep_comp = create_timestep_comp(0, 0., curr_step_obs.obs) + self.assertEqual(timestep, timestep_comp) + + for step_itr in range(env_test._NUM_STEPS - 1): + del step_itr + curr_step_obs = test_env.step(np.array([1], dtype=np.int64)) + timestep = exploration_worker._create_timestep(curr_step_obs) + timestep_comp = create_timestep_comp(1, 0., curr_step_obs.obs) + self.assertEqual(timestep, timestep_comp) + + curr_step_obs = test_env.step(np.array([1], dtype=np.int64)) + timestep = exploration_worker._create_timestep(curr_step_obs) + timestep_comp = create_timestep_comp(2, 47., curr_step_obs.obs) + self.assertEqual(timestep, timestep_comp) + + @mock.patch('subprocess.Popen') + def test_compile_module(self, mock_popen): + mock_popen.side_effect = env_test.mock_interactive_clang + + seq_example_comp = tf.train.SequenceExample() + for i in range(10): + generate_bc_trajectories.add_int_feature(seq_example_comp, i, + 'times_called') + generate_bc_trajectories.add_string_feature(seq_example_comp, 'module', + 'module_name') + generate_bc_trajectories.add_float_feature(seq_example_comp, 47.0, + 'reward') + generate_bc_trajectories.add_int_feature(seq_example_comp, np.mod(i, 5), + 'action') + + exploration_worker = 
generate_bc_trajectories.ExplorationWorker(
+        loaded_module_spec=env_test._MOCK_MODULE,
+        clang_path=env_test._CLANG_PATH,
+        mlgo_task=env_test.MockTask,
+        reward_key='default',
+    )
+
+    seq_example = exploration_worker.compile_module(_policy)
+    self.assertEqual(seq_example, seq_example_comp)

From a2d39b2f9258aae20e6dc7c6dfede28244b9daf0 Mon Sep 17 00:00:00 2001
From: "Teodor V. Marinov" 
Date: Mon, 14 Oct 2024 18:53:55 +0000
Subject: [PATCH 2/4] Suppressing pytype error for protobuf import.

---
 compiler_opt/rl/generate_bc_trajectories_test.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/compiler_opt/rl/generate_bc_trajectories_test.py b/compiler_opt/rl/generate_bc_trajectories_test.py
index 5a277b58..aae60f7b 100644
--- a/compiler_opt/rl/generate_bc_trajectories_test.py
+++ b/compiler_opt/rl/generate_bc_trajectories_test.py
@@ -23,7 +23,7 @@
 from tf_agents.trajectories import policy_step
 from tf_agents.trajectories import time_step
 
-from google.protobuf import text_format
+from google.protobuf import text_format  # pytype: disable=pyi-error
 
 from compiler_opt.rl import generate_bc_trajectories
 from compiler_opt.rl import env

From 59221f104a47af8e316aa34746717cec782779ec Mon Sep 17 00:00:00 2001
From: "Teodor V. Marinov" 
Date: Wed, 16 Oct 2024 22:37:37 +0000
Subject: [PATCH 3/4] Addressing comments by @mtrofin.

---
 compiler_opt/rl/generate_bc_trajectories.py | 96 +++++++++++++--------
 1 file changed, 62 insertions(+), 34 deletions(-)

diff --git a/compiler_opt/rl/generate_bc_trajectories.py b/compiler_opt/rl/generate_bc_trajectories.py
index 977ae15c..90727232 100644
--- a/compiler_opt/rl/generate_bc_trajectories.py
+++ b/compiler_opt/rl/generate_bc_trajectories.py
@@ -17,6 +17,7 @@
 from typing import Any, Callable, Dict, List, Optional, Tuple, Type
 
 from absl import logging
+import dataclasses
 import os
 import shutil
 
@@ -41,7 +42,8 @@ def add_int_feature(
   Args:
     sequence_example: sequence example to use instead of the one belonging to
       the instance
-    feature_value: tf.int64 value of feature
+    feature_value: np.int64 value of feature, this is the type required by
+      tf.train.SequenceExample for an int list
    feature_name: name of feature
   """
   f = sequence_example.feature_lists.feature_list[feature_name].feature.add()
@@ -59,7 +61,8 @@ def add_float_feature(
   Args:
     sequence_example: sequence example to use instead of the one belonging to
       the instance
-    feature_value: tf.int64 value of feature
+    feature_value: np.float32 value of feature, this is the type required by
+      tf.train.SequenceExample for a float list
    feature_name: name of feature
   """
   f = sequence_example.feature_lists.feature_list[feature_name].feature.add()
@@ -111,6 +114,13 @@ def add_feature_list(seq_example: tf.train.SequenceExample,
     add_function(seq_example, feature, feature_name)
 
 
+@dataclasses.dataclass
+class SequenceExampleFeatureNames:
+  action: str = 'action'
+  reward: str = 'reward'
+  module_name: str = 'module_name'
+
+
 class ExplorationWithPolicy:
   """Policy which selects states for exploration.
@@ -219,10 +229,11 @@ def __init__(
       exploration_policy_distr=None,
       obs_action_specs: Optional[Tuple[time_step.TimeStep,
                                        tensor_spec.BoundedTensorSpec,]] = None,
+      reward_key: str = '',
       **kwargs,
   ):
-    self.loaded_module_spec = loaded_module_spec
-    self.use_greedy = use_greedy
+    self._loaded_module_spec = loaded_module_spec
+    self._use_greedy = use_greedy
     if not obs_action_specs:
       obs_spec = None
       action_spec = None
@@ -230,28 +241,28 @@ def __init__(
       obs_spec = obs_action_specs[0].observation
       action_spec = obs_action_specs[1]
 
-    if 'reward_key' not in kwargs:
-      raise KeyError(
+    if reward_key == '':
+      raise TypeError(
           'reward_key not specified in ExplorationWorker initialization.')
-    self.reward_key = kwargs['reward_key']
+    self._reward_key = reward_key
     kwargs.pop('reward_key', None)
-    self.working_dir = None
+    self._working_dir = None
 
-    self.env = env.MLGOEnvironmentBase(
+    self._env = env.MLGOEnvironmentBase(
         clang_path=clang_path,
         task_type=mlgo_task,
         obs_spec=obs_spec,
         action_spec=action_spec,
     )
-    if self.env.action_spec:
-      if self.env.action_spec.dtype != tf.int64:
+    if self._env.action_spec:
+      if self._env.action_spec.dtype != tf.int64:
         raise TypeError(
             f'Environment action_spec type {0} does not match tf.int64'.format(
-                self.env.action_spec.dtype))
-    self.exploration_frac = exploration_frac
-    self.max_exploration_steps = max_exploration_steps
-    self.exploration_policy_distr = exploration_policy_distr
-    logging.info('Reward key in exploration worker: %s', self.reward_key)
+                self._env.action_spec.dtype))
+    self._exploration_frac = exploration_frac
+    self._max_exploration_steps = max_exploration_steps
+    self._exploration_policy_distr = exploration_policy_distr
+    logging.info('Reward key in exploration worker: %s', self._reward_key)
 
   def compile_module(
       self,
@@ -264,42 +275,51 @@ def compile_module(
 
     Returns:
       sequence_example: a tf.train.SequenceExample containing the trajectory
-        from compilation.
+        from compilation. In addition to the features returned from the env,
+        the sequence_example adds the following extra features: action,
+        reward and module_name. action is the action taken at any given step,
+        reward is the reward specified by reward_key (not necessarily the
+        reward returned by the environment) and module_name is the name of
+        the module processed by the compiler.
""" sequence_example = tf.train.SequenceExample() - curr_obs_dict = self.env.reset(self.loaded_module_spec) + curr_obs_dict = self._env.reset(self._loaded_module_spec) try: curr_obs = curr_obs_dict.obs self._process_obs(curr_obs, sequence_example) while curr_obs_dict.step_type != env.StepType.LAST: timestep = self._create_timestep(curr_obs_dict) action = policy(timestep) - add_int_feature(sequence_example, int(action), 'action') - curr_obs_dict = self.env.step(action) + add_int_feature(sequence_example, int(action), + SequenceExampleFeatureNames.action) + curr_obs_dict = self._env.step(action) curr_obs = curr_obs_dict.obs if curr_obs_dict.step_type == env.StepType.LAST: break self._process_obs(curr_obs, sequence_example) except AssertionError as e: logging.error('AssertionError: %s', e) - horizon = len(sequence_example.feature_lists.feature_list['action'].feature) - self.working_dir = curr_obs_dict.working_dir + horizon = len(sequence_example.feature_lists.feature_list[ + SequenceExampleFeatureNames.action].feature) + self._working_dir = curr_obs_dict.working_dir if horizon <= 0: - working_dir_head = os.path.split(self.working_dir)[0] + working_dir_head = os.path.split(self._working_dir)[0] shutil.rmtree(working_dir_head) if horizon <= 0: raise ValueError( f'Policy did not take any inlining decision for module {0}.'.format( - self.loaded_module_spec.name)) + self._loaded_module_spec.name)) if curr_obs_dict.step_type != env.StepType.LAST: raise ValueError( f'Compilation loop exited at step type {0} before last step'.format( curr_obs_dict.step_type)) - native_size = curr_obs_dict.score_policy[self.reward_key] - native_size_list = np.float32(native_size) * np.float32(np.ones(horizon)) - add_feature_list(sequence_example, native_size_list, 'reward') - module_name_list = [self.loaded_module_spec.name for _ in range(horizon)] - add_feature_list(sequence_example, module_name_list, 'module_name') + reward = curr_obs_dict.score_policy[self._reward_key] + reward_list = np.float32(reward) * np.float32(np.ones(horizon)) + add_feature_list(sequence_example, reward_list, + SequenceExampleFeatureNames.reward) + module_name_list = [self._loaded_module_spec.name for _ in range(horizon)] + add_feature_list(sequence_example, module_name_list, + SequenceExampleFeatureNames.module_name) return sequence_example def _create_timestep(self, curr_obs_dict: env.TimeStep): @@ -311,7 +331,7 @@ def _create_timestep(self, curr_obs_dict: env.TimeStep): env.StepType.LAST: 2, } if curr_obs_dict.step_type == env.StepType.LAST: - reward = np.array(curr_obs_dict.score_policy[self.reward_key]) + reward = np.array(curr_obs_dict.score_policy[self._reward_key]) else: reward = np.array(0.) 
      curr_obs_step = step_type_converter[curr_obs_step]
@@ -327,13 +347,21 @@ def _create_timestep(self, curr_obs_dict: env.TimeStep):
 
   def _process_obs(self, curr_obs, sequence_example):
     for curr_obs_feature_name in curr_obs:
-      if not self.env.obs_spec:
+      if not self._env.obs_spec:
         obs_dtype = tf.int64
       else:
-        if curr_obs_feature_name not in self.env.obs_spec.keys():
+        if curr_obs_feature_name not in self._env.obs_spec.keys():
           raise AssertionError(f'Feature name {0} not in obs_spec {1}'.format(
-              curr_obs_feature_name, self.env.obs_spec.keys()))
-        obs_dtype = self.env.obs_spec[curr_obs_feature_name].dtype
+              curr_obs_feature_name, self._env.obs_spec.keys()))
+        if curr_obs_feature_name in [
+            SequenceExampleFeatureNames.action,
+            SequenceExampleFeatureNames.reward,
+            SequenceExampleFeatureNames.module_name
+        ]:
+          raise AssertionError(
+              f'Feature name {curr_obs_feature_name} is already part of '
+              'SequenceExampleFeatureNames')
+        obs_dtype = self._env.obs_spec[curr_obs_feature_name].dtype
         curr_obs_feature = curr_obs[curr_obs_feature_name]
         curr_obs[curr_obs_feature_name] = tf.convert_to_tensor(
             curr_obs_feature, dtype=obs_dtype, name=curr_obs_feature_name)
       add_feature_list(sequence_example, curr_obs_feature,
                        curr_obs_feature_name)

From 0474b603d66ed95ebf10c4ed42619974acdf2a1c Mon Sep 17 00:00:00 2001
From: "Teodor V. Marinov" 
Date: Thu, 17 Oct 2024 00:34:12 +0000
Subject: [PATCH 4/4] Fixing nits.

---
 compiler_opt/rl/generate_bc_trajectories.py | 1 +
 1 file changed, 1 insertion(+)

diff --git a/compiler_opt/rl/generate_bc_trajectories.py b/compiler_opt/rl/generate_bc_trajectories.py
index 90727232..49f21fe4 100644
--- a/compiler_opt/rl/generate_bc_trajectories.py
+++ b/compiler_opt/rl/generate_bc_trajectories.py
@@ -116,6 +116,7 @@ def add_feature_list(seq_example: tf.train.SequenceExample,
 
 @dataclasses.dataclass
 class SequenceExampleFeatureNames:
+  """Feature names for features that are always added to seq example."""
   action: str = 'action'
   reward: str = 'reward'
   module_name: str = 'module_name'
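
Usage sketch (illustrative only, not part of the patches above): the snippet below shows how ExplorationWorker.compile_module could be driven with a hand-written policy, mirroring the call pattern in generate_bc_trajectories_test.py. The clang path, the LoadedModuleSpec instance, the MLGOTask subclass and the reward key are placeholders that depend on the deployment; the constant-action policy exists only to make the example self-contained.

import numpy as np
from tf_agents.trajectories import time_step
from compiler_opt.rl import generate_bc_trajectories

def _always_one_policy(state: time_step.TimeStep) -> np.ndarray:
  # Illustrative policy: ignores the observation and always returns action 1.
  del state
  return np.array(1, dtype=np.int64)

# my_module_spec: a corpus.LoadedModuleSpec for the module to compile (placeholder).
# MyMLGOTask: a subclass of env.MLGOTask describing the compilation task (placeholder).
worker = generate_bc_trajectories.ExplorationWorker(
    loaded_module_spec=my_module_spec,
    clang_path='/path/to/interactive/clang',  # placeholder path
    mlgo_task=MyMLGOTask,
    reward_key='default',  # must be a key produced by the task's score dict
)
# Runs the interactive compilation to completion and returns the trajectory
# (observations, actions, reward, module_name) as a tf.train.SequenceExample.
seq_example = worker.compile_module(_always_one_policy)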