diff --git a/compiler_opt/es/blackbox_evaluator.py b/compiler_opt/es/blackbox_evaluator.py
index 2bd68c51..787ae61d 100644
--- a/compiler_opt/es/blackbox_evaluator.py
+++ b/compiler_opt/es/blackbox_evaluator.py
@@ -25,6 +25,7 @@ from compiler_opt.rl import corpus
 from compiler_opt.es import blackbox_optimizers
 from compiler_opt.distributed import buffered_scheduler
+from compiler_opt.rl import policy_saver
 
 
 class BlackboxEvaluator(metaclass=abc.ABCMeta):
@@ -36,8 +37,8 @@ def __init__(self, train_corpus: corpus.Corpus):
 
   @abc.abstractmethod
   def get_results(
-      self, pool: FixedWorkerPool,
-      perturbations: List[bytes]) -> List[concurrent.futures.Future]:
+      self, pool: FixedWorkerPool, perturbations: List[policy_saver.Policy]
+  ) -> List[concurrent.futures.Future]:
     raise NotImplementedError()
 
   @abc.abstractmethod
@@ -66,8 +67,8 @@ def __init__(self, train_corpus: corpus.Corpus,
     super().__init__(train_corpus)
 
   def get_results(
-      self, pool: FixedWorkerPool,
-      perturbations: List[bytes]) -> List[concurrent.futures.Future]:
+      self, pool: FixedWorkerPool, perturbations: List[policy_saver.Policy]
+  ) -> List[concurrent.futures.Future]:
     if not self._samples:
       for _ in range(self._total_num_perturbations):
         sample = self._train_corpus.sample(self._num_ir_repeats_within_worker)
diff --git a/compiler_opt/es/blackbox_learner.py b/compiler_opt/es/blackbox_learner.py
index 984ef5a3..c4e82b62 100644
--- a/compiler_opt/es/blackbox_learner.py
+++ b/compiler_opt/es/blackbox_learner.py
@@ -223,8 +223,10 @@ def _save_model(self) -> None:
   def get_model_weights(self) -> npt.NDArray[np.float32]:
     return self._model_weights
 
-  def _get_policy_as_bytes(self,
-                           perturbation: npt.NDArray[np.float32]) -> bytes:
+  # TODO: The current conversion is inefficient (performance-wise). We should
+  # consider doing this on the worker side.
+  def _get_policy_from_perturbation(
+      self, perturbation: npt.NDArray[np.float32]) -> policy_saver.Policy:
     sm = tf.saved_model.load(self._tf_policy_path)
     # devectorize the perturbation
     policy_utils.set_vectorized_parameters_for_policy(sm, perturbation)
@@ -242,7 +244,7 @@ def _get_policy_as_bytes(self,
 
     # create and return policy
     policy_obj = policy_saver.Policy.from_filesystem(tfl_dir)
-    return policy_obj.policy
+    return policy_obj
 
   def run_step(self, pool: FixedWorkerPool) -> None:
     """Run a single step of blackbox learning.
@@ -258,14 +260,12 @@ def run_step(self, pool: FixedWorkerPool) -> None:
         p for p in initial_perturbations for p in (p, -p)
     ]
 
-    # convert to bytes for compile job
-    # TODO: current conversion is inefficient.
-    # consider doing this on the worker side
-    perturbations_as_bytes = []
+    perturbations_as_policies = []
     for perturbation in initial_perturbations:
-      perturbations_as_bytes.append(self._get_policy_as_bytes(perturbation))
+      perturbations_as_policies.append(
+          self._get_policy_from_perturbation(perturbation))
 
-    results = self._evaluator.get_results(pool, perturbations_as_bytes)
+    results = self._evaluator.get_results(pool, perturbations_as_policies)
     rewards = self._evaluator.get_rewards(results)
 
     num_pruned = _prune_skipped_perturbations(initial_perturbations, rewards)
diff --git a/compiler_opt/es/blackbox_learner_test.py b/compiler_opt/es/blackbox_learner_test.py
index cb5667c5..5f74a13a 100644
--- a/compiler_opt/es/blackbox_learner_test.py
+++ b/compiler_opt/es/blackbox_learner_test.py
@@ -45,7 +45,8 @@ def __init__(self, arg, *, kwarg):
     self._kwarg = kwarg
     self.function_value = 0.0
 
-  def compile(self, policy: bytes, samples: List[corpus.ModuleSpec]) -> float:
+  def compile(self, policy: policy_saver.Policy,
+              samples: List[corpus.ModuleSpec]) -> float:
     if policy and samples:
       self.function_value += 1.0
     return self.function_value
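For context, a minimal sketch of what a worker-side compile entry point might look like under the new interface, mirroring the test's compile() signature above. ExampleWorker and its body are illustrative assumptions, not part of this change; only policy_saver.Policy, its serialized `policy` bytes field, and corpus.ModuleSpec come from the code being modified.

    # Illustrative sketch only: ExampleWorker is a hypothetical worker, not part
    # of this change. It shows a compile() method accepting a policy_saver.Policy
    # instead of raw bytes.
    from typing import List

    from compiler_opt.rl import corpus
    from compiler_opt.rl import policy_saver


    class ExampleWorker:

      def compile(self, policy: policy_saver.Policy,
                  samples: List[corpus.ModuleSpec]) -> float:
        # The serialized TFLite model is still reachable as bytes when the
        # compile job needs to materialize it on disk.
        serialized_policy: bytes = policy.policy
        del samples  # A real worker would compile each sampled module here.
        return float(len(serialized_policy))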