-
Notifications
You must be signed in to change notification settings - Fork 93
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Refactor out policy evaluation to separate class
This patch moves policy evaluation out of blackbox_learner and into a new BlackboxEvaluator class so that we can change some details of how results are collected, particularly with regard to sampling and how samples are held.
- Loading branch information
1 parent
804cb40
commit d87815b
Showing
4 changed files
with
175 additions
and
85 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,109 @@ | ||
# coding=utf-8 | ||
# Copyright 2020 Google LLC | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
"""Tooling computing rewards in blackbox learner.""" | ||
|
||
import abc | ||
import concurrent.futures | ||
from typing import List, Optional | ||
|
||
from absl import logging | ||
import gin | ||
|
||
from compiler_opt.distributed.worker import FixedWorkerPool | ||
from compiler_opt.rl import corpus | ||
from compiler_opt.es import blackbox_optimizers | ||
from compiler_opt.distributed import buffered_scheduler | ||
|
||
|
||
class BlackboxEvaluator(metaclass=abc.ABCMeta):
  """Blackbox evaluator abstraction.

  Declares the operations an evaluator must provide: scheduling the
  evaluation of perturbations on a worker pool, setting a baseline, and
  converting the resulting futures into rewards.
  """

  @abc.abstractmethod
  def __init__(self, train_corpus: corpus.Corpus):
    """Does nothing; exists so subclasses can chain up through super().

    Args:
      train_corpus: the corpus given to the evaluator; not used here.
    """

  @abc.abstractmethod
  def get_results(
      self, pool: FixedWorkerPool,
      perturbations: List[bytes]) -> List[concurrent.futures.Future]:
    """Evaluates `perturbations` using `pool` and returns the futures.

    Args:
      pool: the worker pool made available to the evaluator.
      perturbations: the perturbations to evaluate.

    Returns:
      A list of futures for the evaluations.
    """
    raise NotImplementedError()

  @abc.abstractmethod
  def set_baseline(self) -> None:
    """Sets the baseline; its meaning is defined by the implementation."""
    raise NotImplementedError()

  @abc.abstractmethod
  def get_rewards(
      self, results: List[concurrent.futures.Future]) -> List[Optional[float]]:
    """Converts the futures in `results` into a list of optional rewards.

    Args:
      results: futures previously obtained from `get_results`.

    Returns:
      A list of rewards whose entries may be None.
    """
    raise NotImplementedError()
|
||
|
||
@gin.configurable
class SamplingBlackboxEvaluator(BlackboxEvaluator):
  """A blackbox evaluator that samples from a corpus to collect reward.

  Samples are drawn from the training corpus the first time `get_results`
  runs with no stored samples and are then reused by later calls on the same
  instance.
  """

  def __init__(self, train_corpus: corpus.Corpus,
               est_type: blackbox_optimizers.EstimatorType,
               total_num_perturbations: int, num_ir_repeats_within_worker: int):
    """Stores the sampling configuration.

    Args:
      train_corpus: corpus that `get_results` samples from.
      est_type: estimator type; for ANTITHETIC every drawn sample is stored
        twice so both halves of a perturbation pair use the same sample.
      total_num_perturbations: number of draws taken from the corpus.
      num_ir_repeats_within_worker: value forwarded to `train_corpus.sample`
        on every draw.
    """
    # Filled lazily by get_results; one entry per scheduled perturbation.
    self._samples = []
    self._train_corpus = train_corpus
    self._total_num_perturbations = total_num_perturbations
    self._num_ir_repeats_within_worker = num_ir_repeats_within_worker
    self._est_type = est_type

    super().__init__(train_corpus)

  def get_results(
      self, pool: FixedWorkerPool,
      perturbations: List[bytes]) -> List[concurrent.futures.Future]:
    """Compiles every perturbation with its sample and waits for completion.

    Args:
      pool: worker pool on which the compile jobs are scheduled.
      perturbations: perturbations to evaluate; they are paired positionally
        with the stored samples.

    Returns:
      The futures produced by the scheduler, returned after all of them have
      finished.
    """
    # NOTE(review): samples are only drawn while the list is empty, so the
    # same samples are reused on every later call - confirm this is intended.
    if not self._samples:
      antithetic = (
          self._est_type == blackbox_optimizers.EstimatorType.ANTITHETIC)
      # An antithetic perturbation pair must share one sample, so the same
      # sample object is stored twice in that mode.
      copies = 2 if antithetic else 1
      for _ in range(self._total_num_perturbations):
        sample = self._train_corpus.sample(self._num_ir_repeats_within_worker)
        self._samples.extend([sample] * copies)

    # zip() stops at the shorter input; make a mismatch visible instead of
    # silently skipping the surplus perturbations or samples.
    if len(perturbations) != len(self._samples):
      logging.warning(
          'Got %d perturbations but %d samples; only the first %d pairs '
          'will be evaluated.', len(perturbations), len(self._samples),
          min(len(perturbations), len(self._samples)))

    compile_args = zip(perturbations, self._samples)

    _, futures = buffered_scheduler.schedule_on_worker_pool(
        action=lambda w, v: w.compile(v[0], v[1]),
        jobs=compile_args,
        worker_pool=pool)

    # Block until every future has finished (default is ALL_COMPLETED).
    concurrent.futures.wait(futures)

    return futures

  def set_baseline(self) -> None:
    """Does nothing; this evaluator keeps no baseline."""
    pass

  def get_rewards(
      self, results: List[concurrent.futures.Future]) -> List[Optional[float]]:
    """Extracts a reward from every future.

    Args:
      results: futures as returned by `get_results`.

    Returns:
      A list aligned with `results`: the future's result when it finished
      without an exception, None (after logging the error) otherwise.
    """
    rewards = []
    for future in results:
      # Query the exception once; None means the future succeeded.
      error = future.exception()
      if error is None:
        rewards.append(future.result())
      else:
        logging.info('Error retrieving result from future: %s', error)
        rewards.append(None)

    return rewards
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,52 @@ | ||
# coding=utf-8 | ||
# Copyright 2020 Google LLC | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
"""Tests for BlackboxEvaluator.""" | ||
|
||
import concurrent.futures | ||
|
||
from absl.testing import absltest | ||
|
||
from compiler_opt.distributed.local import local_worker_manager | ||
from compiler_opt.rl import corpus | ||
from compiler_opt.es import blackbox_learner_test | ||
from compiler_opt.es import blackbox_evaluator | ||
|
||
|
||
class BlackboxEvaluatorTests(absltest.TestCase):
  """Tests for BlackboxEvaluator."""

  def test_sampling_get_results(self):
    """get_results yields one finished future per perturbation/sample pair."""
    with local_worker_manager.LocalWorkerPoolManager(
        blackbox_learner_test.ESWorker, count=3, arg='', kwarg='') as pool:
      perturbations = [b'00', b'01', b'10']
      # Keyword arguments keep each value bound to the intended parameter.
      # The corpus and estimator type are unused because samples are preset.
      evaluator = blackbox_evaluator.SamplingBlackboxEvaluator(
          train_corpus=None,
          est_type=None,
          total_num_perturbations=5,
          num_ir_repeats_within_worker=5)
      # Preset the samples so no corpus is needed for this test.
      # pylint: disable=protected-access
      evaluator._samples = [[corpus.ModuleSpec(name='name1', size=1)],
                            [corpus.ModuleSpec(name='name2', size=1)],
                            [corpus.ModuleSpec(name='name3', size=1)]]
      # pylint: enable=protected-access
      results = evaluator.get_results(pool, perturbations)
      self.assertSequenceAlmostEqual([result.result() for result in results],
                                     [1.0, 1.0, 1.0])

  def test_sampling_get_rewards(self):
    """A failed future maps to None and a successful one to its result."""
    f1 = concurrent.futures.Future()
    # Use a real exception: set_exception(None) leaves the future looking
    # successful, so the error branch of get_rewards would never run.
    f1.set_exception(RuntimeError('compilation failed'))
    f2 = concurrent.futures.Future()
    f2.set_result(2)
    results = [f1, f2]
    evaluator = blackbox_evaluator.SamplingBlackboxEvaluator(
        train_corpus=None,
        est_type=None,
        total_num_perturbations=5,
        num_ir_repeats_within_worker=5)
    rewards = evaluator.get_rewards(results)
    self.assertEqual(rewards, [None, 2])
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters