
Commit

exploration_worker.explore_function commit.
Explores the module using the given policy and the exploration distribution.
This implementation assumes the action set is only {0,1} and will just
switch the action played at explore_step.
tvmarino committed Oct 18, 2024
1 parent 3d18cb1 commit 6aa7c82
Showing 2 changed files with 313 additions and 50 deletions.
221 changes: 188 additions & 33 deletions compiler_opt/rl/generate_bc_trajectories.py
@@ -14,13 +14,14 @@
# limitations under the License.
"""Module for running compilation and collect data for behavior cloning."""

from typing import Any, Callable, Dict, List, Optional, Tuple, Type
from typing import Any, Callable, Dict, List, Optional, Tuple, Type, Generator

from absl import logging
import dataclasses
import os
import shutil

import math
import numpy as np
import tensorflow as tf
from tf_agents.trajectories import policy_step
@@ -32,6 +33,30 @@
from compiler_opt.rl import env


@dataclasses.dataclass
class SequenceExampleFeatureNames:
"""Feature names for features that are always added to seq example."""
action: str = 'action'
reward: str = 'reward'
module_name: str = 'module_name'


def get_loss(seq_example: tf.train.SequenceExample,
reward_key: str = SequenceExampleFeatureNames.reward) -> float:
"""Return the last loss/reward of a trajectory written in a SequenceExample.
Args:
seq_example: tf.train.SequenceExample which contains the trajectory with
all features, including a reward feature
reward_key: the name of the feature that contains the loss/reward.
Returns:
The loss/reward of a trajectory written in a SequenceExample.
"""
return (seq_example.feature_lists.feature_list[reward_key].feature[-1]
.float_list.value[0])
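
As a side note for readers of the diff, here is a minimal stand-alone illustration of what get_loss reads; the reward values are made up, and the import assumes a repo checkout where this module imports cleanly:

import tensorflow as tf

from compiler_opt.rl.generate_bc_trajectories import get_loss

seq = tf.train.SequenceExample()
# Accessing a missing key in the feature_list map creates the entry.
reward_list = seq.feature_lists.feature_list['reward']
for r in [0.5, 0.25, 0.125]:
  reward_list.feature.add().float_list.value.append(r)
print(get_loss(seq))  # 0.125, the reward recorded at the last step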


def add_int_feature(
sequence_example: tf.train.SequenceExample,
feature_value: np.int64,
@@ -114,14 +139,6 @@ def add_feature_list(seq_example: tf.train.SequenceExample,
add_function(seq_example, feature, feature_name)


@dataclasses.dataclass
class SequenceExampleFeatureNames:
"""Feature names for features that are always added to seq example."""
action: str = 'action'
reward: str = 'reward'
module_name: str = 'module_name'


class ExplorationWithPolicy:
"""Policy which selects states for exploration.
@@ -140,6 +157,7 @@ class ExplorationWithPolicy:
explore_policy: randomized policy which is used to compute the gap
curr_step: current step of the trajectory
explore_step: current candidate for exploration step
explore_state: current candidate state for exploration at explore_step
gap: current difference at explore step between probability of most likely
action according to explore_policy and second most likely action
explore_on_features: dict of feature names and functions which specify
@@ -154,13 +172,15 @@ def __init__(
explore_on_features: Optional[Dict[str, Callable[[tf.Tensor],
bool]]] = None,
):
self.replay_prefix = replay_prefix
self.policy = policy
self.explore_policy = explore_policy
self.curr_step = 0
self.explore_step = 0
self.gap = np.inf
self.explore_on_features = explore_on_features
self.explore_step: int = len(replay_prefix) - 1
self.explore_state: time_step.TimeStep | None = None
self._replay_prefix = replay_prefix
self._policy = policy
self._explore_policy = explore_policy
self._curr_step = 0
self._gap = np.inf
self._explore_on_features: Optional[Dict[str, Callable[
[tf.Tensor], bool]]] = explore_on_features
self._stop_exploration = False

def _compute_gap(self, distr: np.ndarray) -> np.float32:
@@ -179,28 +199,29 @@ def get_advice(self, state: time_step.TimeStep) -> np.ndarray:
policy_action: action to take at the current state.
"""
if self.curr_step < len(self.replay_prefix):
self.curr_step += 1
return np.array(self.replay_prefix[self.curr_step - 1])
policy_action = self.policy(state)
if self._curr_step < len(self._replay_prefix):
self._curr_step += 1
return np.array(self._replay_prefix[self._curr_step - 1])
policy_action = self._policy(state)
# explore_policy(state) should play at least one action per state and so
# self.explore_policy(state).action.logits should have at least one entry
distr = tf.nn.softmax(self.explore_policy(state).action.logits).numpy()[0]
distr = tf.nn.softmax(self._explore_policy(state).action.logits).numpy()[0]
curr_gap = self._compute_gap(distr)
# selecting explore_step is done based on smallest encountered gap in the
# play of self.policy. This logic can be changed to have a different type
# of exploration.
if (not self._stop_exploration and distr.shape[0] > 1 and
self.gap > curr_gap):
self.gap = curr_gap
self.explore_step = self.curr_step
if not self._stop_exploration and self.explore_on_features is not None:
for feature_name, explore_on_feature in self.explore_on_features.items():
self._gap > curr_gap):
self._gap = curr_gap
self.explore_step = self._curr_step
self.explore_state = state
if not self._stop_exploration and self._explore_on_features is not None:
for feature_name, explore_on_feature in self._explore_on_features.items():
if explore_on_feature(state.observation[feature_name]):
self.explore_step = self.curr_step
self.explore_step = self._curr_step
self._stop_exploration = True
break
self.curr_step += 1
self._curr_step += 1
return policy_action
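
To make the exploration bookkeeping above concrete, here is a hypothetical, self-contained sketch that drives get_advice with stub policies. The stub classes and the logits are invented for illustration, and the import assumes a repo checkout where this module imports cleanly:

import numpy as np
import tensorflow as tf

from compiler_opt.rl.generate_bc_trajectories import ExplorationWithPolicy


class _StubDistribution:
  """Stands in for the distribution whose .logits get_advice reads above."""

  def __init__(self, logits):
    self.logits = tf.constant([logits], dtype=tf.float32)


class _StubPolicyStep:

  def __init__(self, logits):
    self.action = _StubDistribution(logits)


def stub_policy(state):
  return np.array([0])  # always play action 0


def stub_explore_policy(state):
  # Nearly indifferent over {0, 1}, i.e. a small top-2 gap at every step;
  # _compute_gap (collapsed in this view) measures exactly that closeness.
  return _StubPolicyStep([0.1, 0.0])


explorer = ExplorationWithPolicy([], stub_policy, stub_explore_policy)
for _ in range(3):
  _ = explorer.get_advice(None)  # the real caller passes a time_step.TimeStep
# explorer.explore_step is now 0: with equal gaps at every step, the strict
# comparison above keeps the earliest candidate.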


@@ -214,7 +235,10 @@ class ExplorationWorker(worker.Worker):
env: MLGO environment.
exploration_frac: how often to explore in a trajectory
max_exploration_steps: maximum number of exploration steps
exploration_policy_distr: distribution function from exploration policy.
max_horizon_to_explore: if the horizon under the policy is greater than this,
no exploration is performed
explore_on_features: dict of feature names and functions which specify
when to explore on the respective feature
reward_key: which reward binary to use, must be specified as part of
additional task args (kwargs).
"""
@@ -224,17 +248,17 @@ def __init__(
loaded_module_spec: corpus.LoadedModuleSpec,
clang_path: str,
mlgo_task: Type[env.MLGOTask],
use_greedy: bool = False,
exploration_frac: float = 1.0,
max_exploration_steps: int = 10,
exploration_policy_distr=None,
max_horizon_to_explore=np.inf,
explore_on_features: Optional[Dict[str, Callable[[tf.Tensor],
bool]]] = None,
obs_action_specs: Optional[Tuple[time_step.TimeStep,
tensor_spec.BoundedTensorSpec,]] = None,
reward_key: str = '',
**kwargs,
):
self._loaded_module_spec = loaded_module_spec
self._use_greedy = use_greedy
if not obs_action_specs:
obs_spec = None
action_spec = None
@@ -262,7 +286,8 @@ def __init__(
self._env.action_spec.dtype))
self._exploration_frac = exploration_frac
self._max_exploration_steps = max_exploration_steps
self._exploration_policy_distr = exploration_policy_distr
self._max_horizon_to_explore = max_horizon_to_explore
self._explore_on_features = explore_on_features
logging.info('Reward key in exploration worker: %s', self._reward_key)

def compile_module(
@@ -323,6 +348,136 @@ def compile_module(
SequenceExampleFeatureNames.module_name)
return sequence_example

def explore_function(
self,
policy: Callable[[Optional[time_step.TimeStep]], np.ndarray],
explore_policy: Optional[Callable[[time_step.TimeStep],
policy_step.PolicyStep]] = None,
) -> Tuple[List[tf.train.SequenceExample], List[str], int, float]:
"""Explores the module using the given policy and the exploration distr.
Args:
policy: policy which acts on all states outside of the exploration states.
explore_policy: randomized policy which is used to compute the gap for
exploration and can be used for deciding which actions to explore at
the exploration state.
Returns:
seq_example_list: a tf.train.SequenceExample list containing all the
trajectories from exploration.
working_dir_names: the directories of the compiled binaries
loss_idx: index of the smallest-loss trajectory in seq_example_list.
base_seq_loss: loss of the trajectory compiled with policy.
"""
seq_example_list = []
working_dir_names = []
loss_idx = 0
exploration_steps = 0

if not explore_policy:
base_seq = self.compile_module(policy)
seq_example_list.append(base_seq)
working_dir_names.append(self._working_dir)
return (
seq_example_list,
working_dir_names,
loss_idx,
get_loss(base_seq),
)

base_policy = ExplorationWithPolicy(
[],
policy,
explore_policy,
self._explore_on_features,
)
base_seq = self.compile_module(base_policy.get_advice)
seq_example_list.append(base_seq)
working_dir_names.append(self._working_dir)
base_seq_loss = get_loss(base_seq)
horizon = len(base_seq.feature_lists.feature_list[
SequenceExampleFeatureNames.action].feature)
num_states = int(math.ceil(self._exploration_frac * horizon))
num_states = min(num_states, self._max_exploration_steps)
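# Illustrative numbers only: exploration_frac=0.5, horizon=5 and
# max_exploration_steps=10 give num_states = min(ceil(0.5 * 5), 10) = 3.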
if num_states < 1 or horizon > self._max_horizon_to_explore:
return seq_example_list, working_dir_names, loss_idx, base_seq_loss

seq_losses = [base_seq_loss]
for num_steps in range(num_states):
explore_step = base_policy.explore_step
if explore_step >= horizon:
break
replay_prefix = base_seq.feature_lists.feature_list[
SequenceExampleFeatureNames.action].feature
replay_prefix = self._build_replay_prefix_list(
replay_prefix[:explore_step + 1])
explore_state = base_policy.explore_state
for base_seq, base_policy in self.explore_at_state_generator(
replay_prefix, explore_step, explore_state, policy, explore_policy):
exploration_steps += 1
seq_example_list.append(base_seq)
working_dir_names.append(self._working_dir)
seq_loss = get_loss(base_seq)
seq_losses.append(seq_loss)
# <= biases towards more exploration in the dataset, < towards less exploration
if seq_loss < base_seq_loss:
base_seq_loss = seq_loss
loss_idx = num_steps + 1
logging.info('module exploration losses: %s', seq_losses)
if exploration_steps > self._max_exploration_steps:
return seq_example_list, working_dir_names, loss_idx, base_seq_loss
horizon = len(base_seq.feature_lists.feature_list[
SequenceExampleFeatureNames.action].feature)
# check if we are at the end of the trajectory and the last step was explored
if (explore_step == base_policy.explore_step and
explore_step == horizon - 1):
return seq_example_list, working_dir_names, loss_idx, base_seq_loss

return seq_example_list, working_dir_names, loss_idx, base_seq_loss

def explore_at_state_generator(
self, replay_prefix: List[np.ndarray], explore_step: int,
explore_state: time_step.TimeStep,
policy: Callable[[Optional[time_step.TimeStep]], np.ndarray],
explore_policy: Callable[[time_step.TimeStep], policy_step.PolicyStep]
) -> Generator[Tuple[tf.train.SequenceExample, ExplorationWithPolicy], None,
None]:
"""Generate sequence examples and next exploration policy while exploring.
Generator that defines how to explore at the given explore_step. This
implementation assumes the action set is only {0,1} and will just switch
the action played at explore_step.
Args:
replay_prefix: a replay buffer of actions
explore_step: exploration step in the previous compiled trajectory
explore_state: state for exploration at explore_step
policy: policy which acts on all states outside of the exploration states.
explore_policy: randomized policy which is used to compute the gap for
exploration and can be used for deciding which actions to explore at
the exploration state.
Yields:
base_seq: a tf.train.SequenceExample containing a compiled trajectory
base_policy: the policy used to determine the next exploration step
"""
del explore_state
replay_prefix[explore_step] = 1 - replay_prefix[explore_step]
base_policy = ExplorationWithPolicy(
replay_prefix,
policy,
explore_policy,
self._explore_on_features,
)
base_seq = self.compile_module(base_policy.get_advice)
yield base_seq, base_policy
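
This generator is where the {0,1} assumption from the commit message lives; a tiny stand-alone illustration with made-up actions:

replay_prefix = [0, 1, 1]  # base trajectory actions up to and including explore_step
explore_step = 2
replay_prefix[explore_step] = 1 - replay_prefix[explore_step]
# replay_prefix is now [0, 1, 0]: steps 0 and 1 are replayed verbatim, step 2
# plays the flipped action, and the wrapped policy takes over from step 3 on.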

def _build_replay_prefix_list(self, seq_ex):
ret_list = []
for int_list in seq_ex:
ret_list.append(int_list.int64_list.value[0])
return ret_list
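
For completeness, a small illustration (invented actions, assuming a repo checkout) of the conversion this helper performs from int64 action features back to a plain replay-prefix list:

import tensorflow as tf

seq = tf.train.SequenceExample()
actions = seq.feature_lists.feature_list['action']
for a in [0, 1, 1]:
  actions.feature.add().int64_list.value.append(a)
# Passing actions.feature to _build_replay_prefix_list yields [0, 1, 1].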

def _create_timestep(self, curr_obs_dict: env.TimeStep):
curr_obs = curr_obs_dict.obs
curr_obs_step = curr_obs_dict.step_type
