diff --git a/compiler_opt/es/blackbox_learner.py b/compiler_opt/es/blackbox_learner.py index b37a1f0d..96576d85 100644 --- a/compiler_opt/es/blackbox_learner.py +++ b/compiler_opt/es/blackbox_learner.py @@ -82,6 +82,12 @@ class BlackboxLearnerConfig: step_size: float +@dataclasses.dataclass(frozen=True) +class CorpusSample: + """A sample of a corpus.""" + modules: List[corpus.ModuleSpec] + + def _prune_skipped_perturbations(perturbations: List[npt.NDArray[np.float32]], rewards: List[Optional[float]]): """Remove perturbations that were skipped during the training step. @@ -250,8 +256,9 @@ def _get_results( perturbations: List[bytes]) -> List[concurrent.futures.Future]: if not self._samples: for _ in range(self._config.total_num_perturbations): - sample = self._train_corpus.sample( - self._config.num_ir_repeats_within_worker) + sample = CorpusSample( + self._train_corpus.sample( + self._config.num_ir_repeats_within_worker)) self._samples.append(sample) # add copy of sample for antithetic perturbation pair if self._config.est_type == ( diff --git a/compiler_opt/es/blackbox_learner_test.py b/compiler_opt/es/blackbox_learner_test.py index 12f875f5..da2f0820 100644 --- a/compiler_opt/es/blackbox_learner_test.py +++ b/compiler_opt/es/blackbox_learner_test.py @@ -19,7 +19,6 @@ import concurrent.futures import gin import tempfile -from typing import List import numpy as np import numpy.typing as npt import tensorflow as tf @@ -45,8 +44,9 @@ def __init__(self, arg, *, kwarg): self._kwarg = kwarg self.function_value = 0.0 - def compile(self, policy: bytes, samples: List[corpus.ModuleSpec]) -> float: - if policy and samples: + def compile(self, policy: bytes, + samples: blackbox_learner.CorpusSample) -> float: + if policy and samples.modules: self.function_value += 1.0 return self.function_value else: