explore_function part of ExplorationWorker #384

Merged: 9 commits, Oct 21, 2024
169 changes: 163 additions & 6 deletions compiler_opt/rl/generate_bc_trajectories.py
@@ -14,13 +14,14 @@
# limitations under the License.
"""Module for running compilation and collect data for behavior cloning."""

from typing import Any, Callable, Dict, List, Optional, Tuple, Type
from typing import Any, Callable, Dict, List, Optional, Tuple, Type, Generator

from absl import logging
import dataclasses
import os
import shutil

import math
import numpy as np
import tensorflow as tf
from tf_agents.trajectories import policy_step
@@ -40,6 +41,22 @@ class SequenceExampleFeatureNames:
module_name: str = 'module_name'


def get_loss(seq_example: tf.train.SequenceExample,
Collaborator: Neat. I wonder if we should move the sequence example stuff into its own utility library? If you think that's a good idea, can you tag an issue on this? We can do it after.

             reward_key: str = SequenceExampleFeatureNames.reward) -> float:
"""Return the last loss/reward of a trajectory written in a SequenceExample.

Args:
seq_example: tf.train.SequenceExample which contains the trajectory with
all features, including a reward feature
reward_key: the name of the feature that contains the loss/reward.

Returns:
The loss/reward of a trajectory written in a SequenceExample.
"""
return (seq_example.feature_lists.feature_list[reward_key].feature[-1]
.float_list.value[0])
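
A minimal usage sketch for get_loss, assuming the reward feature is stored as a float_list and passing the feature name 'reward' explicitly rather than relying on SequenceExampleFeatureNames.reward:

```python
import tensorflow as tf

# Build a toy three-step trajectory whose reward feature list ends in 7.0;
# get_loss returns that final value.
seq_ex = tf.train.SequenceExample()
for reward in (0.0, 0.0, 7.0):
  feature = seq_ex.feature_lists.feature_list['reward'].feature.add()
  feature.float_list.value.append(reward)

assert get_loss(seq_ex, reward_key='reward') == 7.0
```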


def add_int_feature(
sequence_example: tf.train.SequenceExample,
feature_value: np.int64,
@@ -140,6 +157,7 @@ class ExplorationWithPolicy:
explore_policy: randomized policy which is used to compute the gap
curr_step: current step of the trajectory
explore_step: current candidate for exploration step
explore_state: current candidate state for exploration at explore_step
    gap: current difference at the explore step between the probabilities of
      the most likely and second most likely actions according to explore_policy
explore_on_features: dict of feature names and functions which specify
@@ -155,6 +173,7 @@ def __init__(
bool]]] = None,
):
self._explore_step: int = len(replay_prefix) - 1
self._explore_state: Optional[time_step.TimeStep] = None
self._replay_prefix = replay_prefix
self._policy = policy
self._explore_policy = explore_policy
@@ -173,6 +192,9 @@ def _compute_gap(self, distr: np.ndarray) -> np.float32:
def get_explore_step(self) -> int:
return self._explore_step

def get_explore_state(self) -> Optional[time_step.TimeStep]:
return self._explore_state
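
The gap described in the class docstring is computed by _compute_gap, whose body is not shown in this hunk. A minimal sketch of the idea, assuming distr is a 1-D array of action probabilities (illustrative only, not the PR's implementation):

```python
import numpy as np

# Gap between the two largest action probabilities under explore_policy.
def compute_gap(distr: np.ndarray) -> np.float32:
  if distr.size < 2:
    return np.float32(0.0)
  top_two = np.sort(distr)[-2:]  # [second largest, largest]
  return np.float32(top_two[1] - top_two[0])
```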

def get_advice(self, state: time_step.TimeStep) -> np.ndarray:
"""Action function for the policy.

@@ -198,6 +220,7 @@ def get_advice(self, state: time_step.TimeStep) -> np.ndarray:
self._gap > curr_gap):
self._gap = curr_gap
self._explore_step = self._curr_step
self._explore_state = state
if not self._stop_exploration and self._explore_on_features is not None:
for feature_name, explore_on_feature in self._explore_on_features.items():
if explore_on_feature(state.observation[feature_name]):
@@ -218,7 +241,10 @@ class ExplorationWorker(worker.Worker):
env: MLGO environment.
    exploration_frac: fraction of the steps of a trajectory at which to explore
max_exploration_steps: maximum number of exploration steps
exploration_policy_distr: distribution function from exploration policy.
    max_horizon_to_explore: no exploration is performed if the horizon of the
      trajectory under the policy exceeds this value
explore_on_features: dict of feature names and functions which specify
when to explore on the respective feature
reward_key: which reward binary to use, must be specified as part of
additional task args (kwargs).
"""
@@ -228,17 +254,17 @@ def __init__(
loaded_module_spec: corpus.LoadedModuleSpec,
clang_path: str,
mlgo_task: Type[env.MLGOTask],
use_greedy: bool = False,
exploration_frac: float = 1.0,
max_exploration_steps: int = 10,
exploration_policy_distr=None,
max_horizon_to_explore=np.inf,
explore_on_features: Optional[Dict[str, Callable[[tf.Tensor],
bool]]] = None,
obs_action_specs: Optional[Tuple[time_step.TimeStep,
tensor_spec.BoundedTensorSpec,]] = None,
reward_key: str = '',
**kwargs,
):
self._loaded_module_spec = loaded_module_spec
self._use_greedy = use_greedy
if not obs_action_specs:
obs_spec = None
action_spec = None
@@ -266,7 +292,8 @@ def __init__(
self._env.action_spec.dtype))
self._exploration_frac = exploration_frac
self._max_exploration_steps = max_exploration_steps
self._exploration_policy_distr = exploration_policy_distr
self._max_horizon_to_explore = max_horizon_to_explore
self._explore_on_features = explore_on_features
logging.info('Reward key in exploration worker: %s', self._reward_key)

def compile_module(
@@ -327,6 +354,136 @@ def compile_module(
SequenceExampleFeatureNames.module_name)
return sequence_example

def explore_function(
self,
policy: Callable[[Optional[time_step.TimeStep]], np.ndarray],
explore_policy: Optional[Callable[[time_step.TimeStep],
policy_step.PolicyStep]] = None,
) -> Tuple[List[tf.train.SequenceExample], List[str], int, float]:
"""Explores the module using the given policy and the exploration distr.

Args:
policy: policy which acts on all states outside of the exploration states.
explore_policy: randomized policy which is used to compute the gap for
exploration and can be used for deciding which actions to explore at
the exploration state.

Returns:
      seq_example_list: a list of tf.train.SequenceExample containing all the
        trajectories from exploration.
      working_dir_names: the directories of the compiled binaries.
      loss_idx: index of the smallest-loss trajectory in seq_example_list.
base_seq_loss: loss of the trajectory compiled with policy.
"""
seq_example_list = []
working_dir_names = []
loss_idx = 0
exploration_steps = 0

if not explore_policy:
base_seq = self.compile_module(policy)
seq_example_list.append(base_seq)
working_dir_names.append(self._working_dir)
return (
seq_example_list,
working_dir_names,
loss_idx,
get_loss(base_seq),
)

base_policy = ExplorationWithPolicy(
[],
policy,
explore_policy,
self._explore_on_features,
)
base_seq = self.compile_module(base_policy.get_advice)
seq_example_list.append(base_seq)
working_dir_names.append(self._working_dir)
base_seq_loss = get_loss(base_seq)
horizon = len(base_seq.feature_lists.feature_list[
SequenceExampleFeatureNames.action].feature)
num_states = int(math.ceil(self._exploration_frac * horizon))
num_states = min(num_states, self._max_exploration_steps)
if num_states < 1 or horizon > self._max_horizon_to_explore:
return seq_example_list, working_dir_names, loss_idx, base_seq_loss

seq_losses = [base_seq_loss]
for num_steps in range(num_states):
explore_step = base_policy.get_explore_step()
if explore_step >= horizon:
break
replay_prefix = base_seq.feature_lists.feature_list[
SequenceExampleFeatureNames.action].feature
replay_prefix = self._build_replay_prefix_list(
replay_prefix[:explore_step + 1])
explore_state = base_policy.get_explore_state()
for base_seq, base_policy in self.explore_at_state_generator(
replay_prefix, explore_step, explore_state, policy, explore_policy):
exploration_steps += 1
seq_example_list.append(base_seq)
working_dir_names.append(self._working_dir)
seq_loss = get_loss(base_seq)
seq_losses.append(seq_loss)
# <= biases towards more exploration in the dataset,
# < biases towards less exploration
if seq_loss < base_seq_loss:
base_seq_loss = seq_loss
loss_idx = num_steps + 1
logging.info('module exploration losses: %s', seq_losses)
if exploration_steps > self._max_exploration_steps:
return seq_example_list, working_dir_names, loss_idx, base_seq_loss
horizon = len(base_seq.feature_lists.feature_list[
SequenceExampleFeatureNames.action].feature)
# Check whether we are at the end of the trajectory and the last step
# was already explored.
if (explore_step == base_policy.get_explore_step() and
explore_step == horizon - 1):
return seq_example_list, working_dir_names, loss_idx, base_seq_loss

return seq_example_list, working_dir_names, loss_idx, base_seq_loss
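
For reference, a hedged sketch of the shape explore_policy is expected to have, based only on its annotation Callable[[time_step.TimeStep], policy_step.PolicyStep]; the logits and the distribution wrapping are illustrative assumptions, not the project's API:

```python
import tensorflow_probability as tfp
from tf_agents.trajectories import policy_step, time_step

# Toy randomized policy: ignores the observation and wraps a categorical
# distribution over the two inlining actions in a PolicyStep.
def toy_explore_policy(state: time_step.TimeStep) -> policy_step.PolicyStep:
  del state  # unused in this toy sketch
  distr = tfp.distributions.Categorical(logits=[0.5, -0.5])
  return policy_step.PolicyStep(action=distr, state=(), info=())
```

A caller would then pass this alongside a greedy action function, e.g. explore_function(greedy_policy, explore_policy=toy_explore_policy), and unpack the returned (seq_example_list, working_dir_names, loss_idx, base_seq_loss) tuple.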

def explore_at_state_generator(
self, replay_prefix: List[np.ndarray], explore_step: int,
explore_state: time_step.TimeStep,
policy: Callable[[Optional[time_step.TimeStep]], np.ndarray],
explore_policy: Callable[[time_step.TimeStep], policy_step.PolicyStep]
) -> Generator[Tuple[tf.train.SequenceExample, ExplorationWithPolicy], None,
None]:
"""Generate sequence examples and next exploration policy while exploring.

Generator that defines how to explore at the given explore_step. This
implementation assumes the action set is only {0,1} and will just switch
the action played at explore_step.

Args:
replay_prefix: a replay buffer of actions
explore_step: exploration step in the previous compiled trajectory
explore_state: state for exploration at explore_step
policy: policy which acts on all states outside of the exploration states.
explore_policy: randomized policy which is used to compute the gap for
exploration and can be used for deciding which actions to explore at
the exploration state.

Yields:
base_seq: a tf.train.SequenceExample containing a compiled trajectory
base_policy: the policy used to determine the next exploration step
"""
del explore_state
replay_prefix[explore_step] = 1 - replay_prefix[explore_step]
base_policy = ExplorationWithPolicy(
replay_prefix,
policy,
explore_policy,
self._explore_on_features,
)
base_seq = self.compile_module(base_policy.get_advice)
yield base_seq, base_policy
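
Because exploration at a state is expressed as a generator, a subclass can yield several alternative trajectories per exploration step instead of the single {0,1} flip above. A hypothetical override sketch for a larger discrete action set, reusing this module's ExplorationWithPolicy (the num_actions value is an assumption, not part of the PR):

```python
class MultiActionExplorationWorker(ExplorationWorker):
  """Hypothetical subclass: explores every alternative discrete action."""

  def explore_at_state_generator(self, replay_prefix, explore_step,
                                 explore_state, policy, explore_policy):
    del explore_state
    num_actions = 3  # assumed size of the discrete action set
    original_action = replay_prefix[explore_step]
    for action in range(num_actions):
      if action == original_action:
        continue
      replay_prefix[explore_step] = action
      base_policy = ExplorationWithPolicy(list(replay_prefix), policy,
                                          explore_policy,
                                          self._explore_on_features)
      yield self.compile_module(base_policy.get_advice), base_policy
```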

def _build_replay_prefix_list(self, seq_ex):
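    """Builds a plain list of int64 action values from a slice of action features."""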
ret_list = []
for int_list in seq_ex:
ret_list.append(int_list.int64_list.value[0])
return ret_list

def _create_timestep(self, curr_obs_dict: env.TimeStep):
curr_obs = curr_obs_dict.obs
curr_obs_step = curr_obs_dict.step_type