Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Pass entire policy in blackbox_learner #366

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 5 additions & 4 deletions compiler_opt/es/blackbox_evaluator.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
from compiler_opt.rl import corpus
from compiler_opt.es import blackbox_optimizers
from compiler_opt.distributed import buffered_scheduler
from compiler_opt.rl import policy_saver


class BlackboxEvaluator(metaclass=abc.ABCMeta):
Expand All @@ -36,8 +37,8 @@ def __init__(self, train_corpus: corpus.Corpus):

@abc.abstractmethod
def get_results(
self, pool: FixedWorkerPool,
perturbations: List[bytes]) -> List[concurrent.futures.Future]:
self, pool: FixedWorkerPool, perturbations: List[policy_saver.Policy]
) -> List[concurrent.futures.Future]:
raise NotImplementedError()

@abc.abstractmethod
Expand Down Expand Up @@ -66,8 +67,8 @@ def __init__(self, train_corpus: corpus.Corpus,
super().__init__(train_corpus)

def get_results(
self, pool: FixedWorkerPool,
perturbations: List[bytes]) -> List[concurrent.futures.Future]:
self, pool: FixedWorkerPool, perturbations: List[policy_saver.Policy]
) -> List[concurrent.futures.Future]:
if not self._samples:
for _ in range(self._total_num_perturbations):
sample = self._train_corpus.sample(self._num_ir_repeats_within_worker)
Expand Down
18 changes: 9 additions & 9 deletions compiler_opt/es/blackbox_learner.py
Original file line number Diff line number Diff line change
Expand Up @@ -223,8 +223,10 @@ def _save_model(self) -> None:
def get_model_weights(self) -> npt.NDArray[np.float32]:
return self._model_weights

def _get_policy_as_bytes(self,
perturbation: npt.NDArray[np.float32]) -> bytes:
# TODO: The current conversion is inefficient (performance-wise). We should
# consider doing this on the worker side.
def _get_policy_from_perturbation(
self, perturbation: npt.NDArray[np.float32]) -> policy_saver.Policy:
sm = tf.saved_model.load(self._tf_policy_path)
# devectorize the perturbation
policy_utils.set_vectorized_parameters_for_policy(sm, perturbation)
Expand All @@ -242,7 +244,7 @@ def _get_policy_as_bytes(self,

# create and return policy
policy_obj = policy_saver.Policy.from_filesystem(tfl_dir)
return policy_obj.policy
return policy_obj

def run_step(self, pool: FixedWorkerPool) -> None:
"""Run a single step of blackbox learning.
Expand All @@ -258,14 +260,12 @@ def run_step(self, pool: FixedWorkerPool) -> None:
p for p in initial_perturbations for p in (p, -p)
]

# convert to bytes for compile job
# TODO: current conversion is inefficient.
# consider doing this on the worker side
perturbations_as_bytes = []
perturbations_as_policies = []
for perturbation in initial_perturbations:
perturbations_as_bytes.append(self._get_policy_as_bytes(perturbation))
perturbations_as_policies.append(
self._get_policy_from_perturbation(perturbation))

results = self._evaluator.get_results(pool, perturbations_as_bytes)
results = self._evaluator.get_results(pool, perturbations_as_policies)
rewards = self._evaluator.get_rewards(results)

num_pruned = _prune_skipped_perturbations(initial_perturbations, rewards)
Expand Down
3 changes: 2 additions & 1 deletion compiler_opt/es/blackbox_learner_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,8 @@ def __init__(self, arg, *, kwarg):
self._kwarg = kwarg
self.function_value = 0.0

def compile(self, policy: bytes, samples: List[corpus.ModuleSpec]) -> float:
def compile(self, policy: policy_saver.Policy,
samples: List[corpus.ModuleSpec]) -> float:
if policy and samples:
self.function_value += 1.0
return self.function_value
Expand Down