Initial ExplorationWorker commit #383

Merged: 4 commits, Oct 17, 2024
Changes from 1 commit
96 changes: 62 additions & 34 deletions compiler_opt/rl/generate_bc_trajectories.py
@@ -17,6 +17,7 @@
from typing import Any, Callable, Dict, List, Optional, Tuple, Type

from absl import logging
import dataclasses
import os
import shutil

@@ -41,7 +42,8 @@ def add_int_feature(
Args:
sequence_example: sequence example to use instead of the one belonging to
the instance
feature_value: tf.int64 value of feature
feature_value: np.int64 value of feature; this is the type required by
tf.train.SequenceExample for an int64 list
feature_name: name of feature
"""
f = sequence_example.feature_lists.feature_list[feature_name].feature.add()
@@ -59,7 +61,8 @@ def add_float_feature(
Args:
sequence_example: sequence example to use instead of the one belonging to
the instance
feature_value: tf.int64 value of feature
feature_value: np.float32 value of feature; this is the type required by
tf.train.SequenceExample for a float list
feature_name: name of feature
"""
f = sequence_example.feature_lists.feature_list[feature_name].feature.add()
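For context (not part of the diff), a minimal sketch of the tf.train.SequenceExample mechanics these helpers wrap; the 'action' and 'reward' feature names here are illustrative:

import numpy as np
import tensorflow as tf

sequence_example = tf.train.SequenceExample()

# Mirrors add_int_feature: append one int64 step value to a named feature list.
f = sequence_example.feature_lists.feature_list['action'].feature.add()
f.int64_list.value.append(np.int64(1))

# Mirrors add_float_feature: append one float32 step value.
g = sequence_example.feature_lists.feature_list['reward'].feature.add()
g.float_list.value.append(np.float32(0.5))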
@@ -111,6 +114,13 @@ def add_feature_list(seq_example: tf.train.SequenceExample,
add_function(seq_example, feature, feature_name)


@dataclasses.dataclass
class SequenceExampleFeatureNames:
action: str = 'action'
reward: str = 'reward'
module_name: str = 'module_name'

Collaborator: nit: add a comment or doc string saying what this is for ("feature names that we depend on" for example). Also use these in the test.

Collaborator Author: done.

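As context for the review thread above, a minimal sketch of how a test might use the class as a namespace of shared constants rather than string literals (assumes SequenceExampleFeatureNames from this module is in scope):

import tensorflow as tf

sequence_example = tf.train.SequenceExample()
# Looking features up by the shared constant keeps the test and the
# worker in sync if a feature name ever changes.
actions = sequence_example.feature_lists.feature_list[
    SequenceExampleFeatureNames.action].feature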

class ExplorationWithPolicy:
"""Policy which selects states for exploration.

@@ -219,39 +229,40 @@ def __init__(
exploration_policy_distr=None,
obs_action_specs: Optional[Tuple[time_step.TimeStep,
tensor_spec.BoundedTensorSpec,]] = None,
reward_key: str = '',
**kwargs,
):
self.loaded_module_spec = loaded_module_spec
self.use_greedy = use_greedy
self._loaded_module_spec = loaded_module_spec
self._use_greedy = use_greedy
if not obs_action_specs:
obs_spec = None
action_spec = None
else:
obs_spec = obs_action_specs[0].observation
action_spec = obs_action_specs[1]

if 'reward_key' not in kwargs:
raise KeyError(
if reward_key == '':
raise TypeError(
'reward_key not specified in ExplorationWorker initialization.')
self.reward_key = kwargs['reward_key']
self._reward_key = reward_key
kwargs.pop('reward_key', None)
self.working_dir = None
self._working_dir = None

self.env = env.MLGOEnvironmentBase(
self._env = env.MLGOEnvironmentBase(
clang_path=clang_path,
task_type=mlgo_task,
obs_spec=obs_spec,
action_spec=action_spec,
)
if self.env.action_spec:
if self.env.action_spec.dtype != tf.int64:
if self._env.action_spec:
if self._env.action_spec.dtype != tf.int64:
raise TypeError(
'Environment action_spec type {0} does not match tf.int64'.format(
self.env.action_spec.dtype))
self.exploration_frac = exploration_frac
self.max_exploration_steps = max_exploration_steps
self.exploration_policy_distr = exploration_policy_distr
logging.info('Reward key in exploration worker: %s', self.reward_key)
self._env.action_spec.dtype))
self._exploration_frac = exploration_frac
self._max_exploration_steps = max_exploration_steps
self._exploration_policy_distr = exploration_policy_distr
logging.info('Reward key in exploration worker: %s', self._reward_key)
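A hypothetical smoke check of the new guard; the constructor head is not shown in this hunk, so the keyword arguments below are inferred from the assignments above, and the placeholders are never used because the reward_key check fires before the environment is built:

try:
  ExplorationWorker(
      loaded_module_spec=None,  # placeholder: a loaded corpus module
      use_greedy=False,
      clang_path=None,          # placeholder: path to clang
      mlgo_task=None,           # placeholder: MLGO task type
  )
except TypeError as e:
  print(e)  # reward_key not specified in ExplorationWorker initialization.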

def compile_module(
self,
@@ -264,42 +275,51 @@ def compile_module(

Returns:
sequence_example: a tf.train.SequenceExample containing the trajectory
from compilation.
from compilation. In addition to the features returned from the env,
the sequence_example adds the following extra features: action,
reward and module_name. action is the action taken at a given step;
reward is the reward specified by reward_key, not necessarily the
reward returned by the environment; and module_name is the name of
the module processed by the compiler.
"""
sequence_example = tf.train.SequenceExample()
curr_obs_dict = self.env.reset(self.loaded_module_spec)
curr_obs_dict = self._env.reset(self._loaded_module_spec)
try:
curr_obs = curr_obs_dict.obs
self._process_obs(curr_obs, sequence_example)
while curr_obs_dict.step_type != env.StepType.LAST:
timestep = self._create_timestep(curr_obs_dict)
action = policy(timestep)
add_int_feature(sequence_example, int(action), 'action')
curr_obs_dict = self.env.step(action)
add_int_feature(sequence_example, int(action),
SequenceExampleFeatureNames.action)
curr_obs_dict = self._env.step(action)
curr_obs = curr_obs_dict.obs
if curr_obs_dict.step_type == env.StepType.LAST:
break
self._process_obs(curr_obs, sequence_example)
except AssertionError as e:
logging.error('AssertionError: %s', e)
horizon = len(sequence_example.feature_lists.feature_list['action'].feature)
self.working_dir = curr_obs_dict.working_dir
horizon = len(sequence_example.feature_lists.feature_list[
SequenceExampleFeatureNames.action].feature)
self._working_dir = curr_obs_dict.working_dir
if horizon <= 0:
working_dir_head = os.path.split(self.working_dir)[0]
working_dir_head = os.path.split(self._working_dir)[0]
shutil.rmtree(working_dir_head)
if horizon <= 0:
raise ValueError(
'Policy did not take any inlining decision for module {0}.'.format(
self.loaded_module_spec.name))
self._loaded_module_spec.name))
if curr_obs_dict.step_type != env.StepType.LAST:
raise ValueError(
'Compilation loop exited at step type {0} before last step'.format(
curr_obs_dict.step_type))
native_size = curr_obs_dict.score_policy[self.reward_key]
native_size_list = np.float32(native_size) * np.float32(np.ones(horizon))
add_feature_list(sequence_example, native_size_list, 'reward')
module_name_list = [self.loaded_module_spec.name for _ in range(horizon)]
add_feature_list(sequence_example, module_name_list, 'module_name')
reward = curr_obs_dict.score_policy[self._reward_key]
reward_list = np.float32(reward) * np.float32(np.ones(horizon))
add_feature_list(sequence_example, reward_list,
SequenceExampleFeatureNames.reward)
module_name_list = [self._loaded_module_spec.name for _ in range(horizon)]
add_feature_list(sequence_example, module_name_list,
SequenceExampleFeatureNames.module_name)
return sequence_example
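A hypothetical way to inspect the returned trajectory, assuming sequence_example came from compile_module above; each step carries one int64 action and one float32 reward broadcast from the final score:

feature_list = sequence_example.feature_lists.feature_list
actions = [
    f.int64_list.value[0]
    for f in feature_list[SequenceExampleFeatureNames.action].feature
]
rewards = [
    f.float_list.value[0]
    for f in feature_list[SequenceExampleFeatureNames.reward].feature
]
assert len(rewards) == len(actions)  # one broadcast reward per step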

def _create_timestep(self, curr_obs_dict: env.TimeStep):
@@ -311,7 +331,7 @@ def _create_timestep(self, curr_obs_dict: env.TimeStep):
env.StepType.LAST: 2,
}
if curr_obs_dict.step_type == env.StepType.LAST:
reward = np.array(curr_obs_dict.score_policy[self.reward_key])
reward = np.array(curr_obs_dict.score_policy[self._reward_key])
else:
reward = np.array(0.)
curr_obs_step = step_type_converter[curr_obs_step]
@@ -327,13 +347,21 @@ def _create_timestep(self, curr_obs_dict: env.TimeStep):

def _process_obs(self, curr_obs, sequence_example):
for curr_obs_feature_name in curr_obs:
if not self.env.obs_spec:
if not self._env.obs_spec:
obs_dtype = tf.int64
else:
if curr_obs_feature_name not in self.env.obs_spec.keys():
if curr_obs_feature_name not in self._env.obs_spec.keys():
raise AssertionError('Feature name {0} not in obs_spec {1}'.format(
curr_obs_feature_name, self.env.obs_spec.keys()))
obs_dtype = self.env.obs_spec[curr_obs_feature_name].dtype
curr_obs_feature_name, self._env.obs_spec.keys()))
if curr_obs_feature_name in [
SequenceExampleFeatureNames.action,
SequenceExampleFeatureNames.reward,
SequenceExampleFeatureNames.module_name
]:
raise AssertionError(
'Feature name {0} already part of SequenceExampleFeatureNames'
.format(curr_obs_feature_name))
obs_dtype = self._env.obs_spec[curr_obs_feature_name].dtype
curr_obs_feature = curr_obs[curr_obs_feature_name]
curr_obs[curr_obs_feature_name] = tf.convert_to_tensor(
curr_obs_feature, dtype=obs_dtype, name=curr_obs_feature_name)
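For illustration, a standalone sketch of the dtype coercion _process_obs applies when no obs_spec is given (the feature name is illustrative):

import numpy as np
import tensorflow as tf

# Without an obs_spec, every observation feature falls back to tf.int64.
raw_feature = np.array([3, 1, 4])  # hypothetical raw observation
obs_tensor = tf.convert_to_tensor(
    raw_feature, dtype=tf.int64, name='callee_basic_block_count')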