Commit 6a7519d
Reorganize imports, freeze config, add todo for policy byte conversion, correct type annotations, use join for paths
salaast committed Aug 8, 2023
1 parent dc9d6a5 commit 6a7519d
Showing 2 changed files with 20 additions and 15 deletions.
31 changes: 18 additions & 13 deletions compiler_opt/es/blackbox_learner.py
@@ -14,6 +14,7 @@
 # limitations under the License.
 """Class for coordinating blackbox optimization."""
 
+import os
 from absl import logging
 import concurrent.futures
 import dataclasses
@@ -26,18 +27,18 @@
 from typing import Any, Callable, List, Optional
 
 from compiler_opt.distributed import buffered_scheduler
-from compiler_opt.distributed.worker import Worker, FixedWorkerPool
-from compiler_opt.es import blackbox_optimizers, policy_utils
-from compiler_opt.rl import policy_saver, corpus
+from compiler_opt.distributed.worker import FixedWorkerPool
+from compiler_opt.es import blackbox_optimizers
+from compiler_opt.es import policy_utils
+from compiler_opt.rl import corpus
+from compiler_opt.rl import policy_saver
 
 # If less than 40% of requests succeed, skip the step.
 _SKIP_STEP_SUCCESS_RATIO = 0.4
 
-OUTPUT_SIGNATURE = 'output_spec.json'
-
 
 @gin.configurable
-@dataclasses.dataclass
+@dataclasses.dataclass(frozen=True)
 class BlackboxLearnerConfig:
   """Hyperparameter configuration for BlackboxLearner."""
 
@@ -81,7 +82,8 @@ class BlackboxLearnerConfig:
   step_size: float
 
 
-def _prune_skipped_perturbations(perturbations, rewards):
+def _prune_skipped_perturbations(perturbations: List[npt.NDArray[np.float32]],
+                                 rewards: List[Optional[float]]):
   """Remove perturbations that were skipped during the training step.
 
   Perturbations may be skipped due to an early exit condition or a server error
@@ -111,7 +113,7 @@ def _prune_skipped_perturbations(perturbations, rewards):
   return len(indices_to_prune)
 
 
-class BlackboxLearner(Worker):
+class BlackboxLearner:
   """Implementation of blackbox learning."""
 
   def __init__(self,
@@ -251,7 +253,7 @@ def _get_results(
     while not_done:
       # update lists as work gets done
       _, not_done = concurrent.futures.wait(
-          not_done, return_when='FIRST_COMPLETED')
+          not_done, return_when=concurrent.futures.FIRST_COMPLETED)
 
     return futures
 
@@ -262,13 +264,14 @@ def _get_policy_as_bytes(self,
     policy_utils.set_vectorized_parameters_for_policy(sm, perturbation)
 
     with tempfile.TemporaryDirectory() as tmpdir:
-      sm_dir = tmpdir + '/sm'
+      sm_dir = os.path.join(tmpdir, 'sm')
       tf.saved_model.save(sm, sm_dir, signatures=sm.signatures)
-      tf.io.gfile.copy(self._tf_policy_path + '/' + OUTPUT_SIGNATURE,
-                       sm_dir + '/' + OUTPUT_SIGNATURE)
+      src = os.path.join(self._tf_policy_path, policy_saver.OUTPUT_SIGNATURE)
+      dst = os.path.join(sm_dir, policy_saver.OUTPUT_SIGNATURE)
+      tf.io.gfile.copy(src, dst)
 
       # convert to tflite
-      tfl_dir = tmpdir + '/tfl'
+      tfl_dir = os.path.join(tmpdir, 'tfl')
       policy_saver.convert_mlgo_model(sm_dir, tfl_dir)
 
       # create and return policy
@@ -288,6 +291,8 @@ def run_step(self, pool: FixedWorkerPool) -> None:
     ]
 
     # convert to bytes for compile job
+    # TODO: current conversion is inefficient.
+    # consider doing this on the worker side
    perturbations_as_bytes = []
     for perturbation in initial_perturbations:
       perturbations_as_bytes.append(self._get_policy_as_bytes(perturbation))
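For readers skimming the diff, here is a minimal, standalone sketch (not part of the commit; ExampleConfig and the toy futures are hypothetical) of what three of these changes buy: a frozen dataclass config rejects accidental mutation, os.path.join replaces manual '/'-concatenation, and the concurrent.futures.FIRST_COMPLETED constant replaces a bare string.

# Standalone illustration only; names here do not come from compiler_opt.
import concurrent.futures
import dataclasses
import os
import tempfile


@dataclasses.dataclass(frozen=True)
class ExampleConfig:
  """Hypothetical frozen hyperparameter config."""
  step_size: float


config = ExampleConfig(step_size=0.01)
try:
  config.step_size = 0.02  # frozen=True rejects mutation after construction
except dataclasses.FrozenInstanceError:
  print('config is immutable')

# os.path.join builds subdirectory paths without hand-placed separators.
with tempfile.TemporaryDirectory() as tmpdir:
  print(os.path.join(tmpdir, 'sm'))
  print(os.path.join(tmpdir, 'tfl'))

# Waiting on futures with the module constant rather than a bare string.
with concurrent.futures.ThreadPoolExecutor() as pool:
  futures = [pool.submit(pow, 2, n) for n in range(4)]
  done, not_done = concurrent.futures.wait(
      futures, return_when=concurrent.futures.FIRST_COMPLETED)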
4 changes: 2 additions & 2 deletions compiler_opt/es/es_worker.py
@@ -31,8 +31,8 @@ def __init__(self, arg, *, kwarg):
     self._kwarg = kwarg
     self.function_value = 0.0
 
-  def temp_compile(self, policy: List[bytes],
-                   samples: List[List[corpus.ModuleSpec]]) -> float:
+  def temp_compile(self, policy: bytes,
+                   samples: List[corpus.ModuleSpec]) -> float:
     if policy and samples:
       self.function_value += 1.0
     return self.function_value
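A short, self-contained sketch (not the project's code; ToyWorker and ModuleSpec are hypothetical stand-ins) showing the shape of the corrected temp_compile signature: a single serialized policy as bytes scored against a flat list of module specs, rather than nested lists.

from typing import List, NamedTuple


class ModuleSpec(NamedTuple):
  """Hypothetical stand-in for corpus.ModuleSpec."""
  name: str


class ToyWorker:
  """Hypothetical stand-in for the worker in es_worker.py."""

  def __init__(self):
    self.function_value = 0.0

  def temp_compile(self, policy: bytes,
                   samples: List[ModuleSpec]) -> float:
    # Mirrors the diff: bump the dummy score when both inputs are present.
    if policy and samples:
      self.function_value += 1.0
    return self.function_value


worker = ToyWorker()
print(worker.temp_compile(b'tflite-policy-bytes', [ModuleSpec('module.bc')]))  # 1.0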