Skip to content

Commit

Permalink
Refactor out policy evaluation to separate class
Browse files Browse the repository at this point in the history
This patch refactors out policy evaluation in blackbox_learner to a new
BlackboxEvaluator class so that we can change some details of how things
are collected, particularly with regards to sampling and how samples are
held.
  • Loading branch information
boomanaiden154 committed Sep 13, 2024
1 parent 804cb40 commit d87815b
Show file tree
Hide file tree
Showing 4 changed files with 175 additions and 85 deletions.
109 changes: 109 additions & 0 deletions compiler_opt/es/blackbox_evaluator.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,109 @@
# coding=utf-8
# Copyright 2020 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tooling computing rewards in blackbox learner."""

import abc
import concurrent.futures
from typing import List, Optional

from absl import logging
import gin

from compiler_opt.distributed.worker import FixedWorkerPool
from compiler_opt.rl import corpus
from compiler_opt.es import blackbox_optimizers
from compiler_opt.distributed import buffered_scheduler


class BlackboxEvaluator(metaclass=abc.ABCMeta):
"""Blockbox evaluator abstraction."""

@abc.abstractmethod
def __init__(self, train_corpus: corpus.Corpus):
pass

@abc.abstractmethod
def get_results(
self, pool: FixedWorkerPool,
perturbations: List[bytes]) -> List[concurrent.futures.Future]:
raise NotImplementedError()

@abc.abstractmethod
def set_baseline(self) -> None:
raise NotImplementedError()

@abc.abstractmethod
def get_rewards(
self, results: List[concurrent.futures.Future]) -> List[Optional[float]]:
raise NotImplementedError()


@gin.configurable
class SamplingBlackboxEvaluator(BlackboxEvaluator):
"""A blackbox evaluator that samples from a corpus to collect reward."""

def __init__(self, train_corpus: corpus.Corpus,
est_type: blackbox_optimizers.EstimatorType,
total_num_perturbations: int, num_ir_repeats_within_worker: int):
self._samples = []
self._train_corpus = train_corpus
self._total_num_perturbations = total_num_perturbations
self._num_ir_repeats_within_worker = num_ir_repeats_within_worker
self._est_type = est_type

super().__init__(train_corpus)

def get_results(
self, pool: FixedWorkerPool,
perturbations: List[bytes]) -> List[concurrent.futures.Future]:
if not self._samples:
for _ in range(self._total_num_perturbations):
sample = self._train_corpus.sample(self._num_ir_repeats_within_worker)
self._samples.append(sample)
# add copy of sample for antithetic perturbation pair
if self._est_type == (blackbox_optimizers.EstimatorType.ANTITHETIC):
self._samples.append(sample)

compile_args = zip(perturbations, self._samples)

_, futures = buffered_scheduler.schedule_on_worker_pool(
action=lambda w, v: w.compile(v[0], v[1]),
jobs=compile_args,
worker_pool=pool)

not_done = futures
# wait for all futures to finish
while not_done:
# update lists as work gets done
_, not_done = concurrent.futures.wait(
not_done, return_when=concurrent.futures.FIRST_COMPLETED)

return futures

def set_baseline(self) -> None:
pass

def get_rewards(
self, results: List[concurrent.futures.Future]) -> List[Optional[float]]:
rewards = [None] * len(results)

for i in range(len(results)):
if not results[i].exception():
rewards[i] = results[i].result()
else:
logging.info('Error retrieving result from future: %s',
str(results[i].exception()))

return rewards
52 changes: 52 additions & 0 deletions compiler_opt/es/blackbox_evaluator_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
# coding=utf-8
# Copyright 2020 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for BlackboxEvaluator."""

import concurrent.futures

from absl.testing import absltest

from compiler_opt.distributed.local import local_worker_manager
from compiler_opt.rl import corpus
from compiler_opt.es import blackbox_learner_test
from compiler_opt.es import blackbox_evaluator


class BlackboxEvaluatorTests(absltest.TestCase):
"""Tests for BlackboxEvaluator."""

def test_sampling_get_results(self):
with local_worker_manager.LocalWorkerPoolManager(
blackbox_learner_test.ESWorker, count=3, arg='', kwarg='') as pool:
perturbations = [b'00', b'01', b'10']
evaluator = blackbox_evaluator.SamplingBlackboxEvaluator(None, 5, 5, None)
# pylint: disable=protected-access
evaluator._samples = [[corpus.ModuleSpec(name='name1', size=1)],
[corpus.ModuleSpec(name='name2', size=1)],
[corpus.ModuleSpec(name='name3', size=1)]]
# pylint: enable=protected-access
results = evaluator.get_results(pool, perturbations)
self.assertSequenceAlmostEqual([result.result() for result in results],
[1.0, 1.0, 1.0])

def test_sampling_get_rewards(self):
f1 = concurrent.futures.Future()
f1.set_exception(None)
f2 = concurrent.futures.Future()
f2.set_result(2)
results = [f1, f2]
evaluator = blackbox_evaluator.SamplingBlackboxEvaluator(None, 5, 5, None)
rewards = evaluator.get_rewards(results)
self.assertEqual(rewards, [None, 2])
67 changes: 8 additions & 59 deletions compiler_opt/es/blackbox_learner.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@

import os
from absl import logging
import concurrent.futures
import dataclasses
import gin
import math
Expand All @@ -26,12 +25,12 @@
import tensorflow as tf
from typing import List, Optional, Protocol

from compiler_opt.distributed import buffered_scheduler
from compiler_opt.distributed.worker import FixedWorkerPool
from compiler_opt.es import blackbox_optimizers
from compiler_opt.es import policy_utils
from compiler_opt.rl import corpus
from compiler_opt.rl import policy_saver
from compiler_opt.es import blackbox_evaluator

# If less than 40% of requests succeed, skip the step.
_SKIP_STEP_SUCCESS_RATIO = 0.4
Expand Down Expand Up @@ -63,14 +62,8 @@ class BlackboxLearnerConfig:
# 0 means all
num_top_directions: int

# How many IR files to try a single perturbation on?
num_ir_repeats_within_worker: int

# How many times should we reuse IR to test different policies?
num_ir_repeats_across_worker: int

# How many IR files to sample from the test corpus at each iteration
num_exact_evals: int
# The type of evaluator to use.
evaluator: type[blackbox_evaluator.BlackboxEvaluator]

# How many perturbations to attempt at each perturbation
total_num_perturbations: int
Expand Down Expand Up @@ -162,12 +155,11 @@ def __init__(self,
self._deadline = deadline
self._seed = seed

# While we're waiting for the ES requests, we can
# collect samples for the next round of training.
self._samples = []

self._summary_writer = tf.summary.create_file_writer(output_dir)

self._evaluator = self._config.evaluator(self._train_corpus,
self._config.est_type)

def _get_perturbations(self) -> List[npt.NDArray[np.float32]]:
"""Get perturbations for the model weights."""
perturbations = []
Expand All @@ -178,20 +170,6 @@ def _get_perturbations(self) -> List[npt.NDArray[np.float32]]:
self._config.precision_parameter)
return perturbations

def _get_rewards(
self, results: List[concurrent.futures.Future]) -> List[Optional[float]]:
"""Convert ES results to reward numbers."""
rewards = [None] * len(results)

for i in range(len(results)):
if not results[i].exception():
rewards[i] = results[i].result()
else:
logging.info('Error retrieving result from future: %s',
str(results[i].exception()))

return rewards

def _update_model(self, perturbations: List[npt.NDArray[np.float32]],
rewards: List[float]) -> None:
"""Update the model given a list of perturbations and rewards."""
Expand Down Expand Up @@ -245,35 +223,6 @@ def _save_model(self) -> None:
def get_model_weights(self) -> npt.NDArray[np.float32]:
return self._model_weights

def _get_results(
self, pool: FixedWorkerPool,
perturbations: List[bytes]) -> List[concurrent.futures.Future]:
if not self._samples:
for _ in range(self._config.total_num_perturbations):
sample = self._train_corpus.sample(
self._config.num_ir_repeats_within_worker)
self._samples.append(sample)
# add copy of sample for antithetic perturbation pair
if self._config.est_type == (
blackbox_optimizers.EstimatorType.ANTITHETIC):
self._samples.append(sample)

compile_args = zip(perturbations, self._samples)

_, futures = buffered_scheduler.schedule_on_worker_pool(
action=lambda w, v: w.compile(v[0], v[1]),
jobs=compile_args,
worker_pool=pool)

not_done = futures
# wait for all futures to finish
while not_done:
# update lists as work gets done
_, not_done = concurrent.futures.wait(
not_done, return_when=concurrent.futures.FIRST_COMPLETED)

return futures

def _get_policy_as_bytes(self,
perturbation: npt.NDArray[np.float32]) -> bytes:
sm = tf.saved_model.load(self._tf_policy_path)
Expand Down Expand Up @@ -316,8 +265,8 @@ def run_step(self, pool: FixedWorkerPool) -> None:
for perturbation in initial_perturbations:
perturbations_as_bytes.append(self._get_policy_as_bytes(perturbation))

results = self._get_results(pool, perturbations_as_bytes)
rewards = self._get_rewards(results)
results = self._evaluator.get_results(pool, perturbations_as_bytes)
rewards = self._evaluator.get_rewards(results)

num_pruned = _prune_skipped_perturbations(initial_perturbations, rewards)
logging.info('Pruned [%d]', num_pruned)
Expand Down
32 changes: 6 additions & 26 deletions compiler_opt/es/blackbox_learner_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@

import os
from absl.testing import absltest
import concurrent.futures
import gin
import tempfile
from typing import List
Expand All @@ -32,6 +31,7 @@
from compiler_opt.es import blackbox_optimizers
from compiler_opt.rl import corpus, inlining, policy_saver, registry
from compiler_opt.rl.inlining import config as inlining_config
from compiler_opt.es import blackbox_evaluator


@gin.configurable
Expand Down Expand Up @@ -59,6 +59,10 @@ class BlackboxLearnerTests(absltest.TestCase):
def setUp(self):
super().setUp()

gin.bind_parameter('SamplingBlackboxEvaluator.total_num_perturbations', 5)
gin.bind_parameter('SamplingBlackboxEvaluator.num_ir_repeats_within_worker',
5)

self._learner_config = blackbox_learner.BlackboxLearnerConfig(
total_steps=1,
blackbox_optimizer=blackbox_optimizers.Algorithm.MONTE_CARLO,
Expand All @@ -67,9 +71,7 @@ def setUp(self):
hyperparameters_update_method=blackbox_optimizers.UpdateMethod
.NO_METHOD,
num_top_directions=0,
num_ir_repeats_within_worker=1,
num_ir_repeats_across_worker=0,
num_exact_evals=1,
evaluator=blackbox_evaluator.SamplingBlackboxEvaluator,
total_num_perturbations=3,
precision_parameter=1,
step_size=1.0)
Expand Down Expand Up @@ -150,28 +152,6 @@ def test_get_perturbations(self):
for value in perturbation:
self.assertAlmostEqual(value, rng.normal())

def test_get_results(self):
with local_worker_manager.LocalWorkerPoolManager(
ESWorker, count=3, arg='', kwarg='') as pool:
self._samples = [[corpus.ModuleSpec(name='name1', size=1)],
[corpus.ModuleSpec(name='name2', size=1)],
[corpus.ModuleSpec(name='name3', size=1)]]
perturbations = [b'00', b'01', b'10']
# pylint: disable=protected-access
results = self._learner._get_results(pool, perturbations)
# pylint: enable=protected-access
self.assertSequenceAlmostEqual([result.result() for result in results],
[1.0, 1.0, 1.0])

def test_get_rewards(self):
f1 = concurrent.futures.Future()
f1.set_exception(None)
f2 = concurrent.futures.Future()
f2.set_result(2)
results = [f1, f2]
rewards = self._learner._get_rewards(results) # pylint: disable=protected-access
self.assertEqual(rewards, [None, 2])

def test_prune_skipped_perturbations(self):
perturbations = [1, 2, 3, 4, 5]
rewards = [1, None, 1, None, 1]
Expand Down

0 comments on commit d87815b

Please sign in to comment.