From 393bd273aabff2e6c7d4cff3de76d5c171fd62a9 Mon Sep 17 00:00:00 2001 From: vizier-team Date: Tue, 1 Aug 2023 15:29:11 -0700 Subject: [PATCH] Internal Change PiperOrigin-RevId: 552941861 --- .../_src/algorithms/designers/gp/gp_models.py | 13 +- .../algorithms/designers/gp/gp_models_test.py | 12 +- vizier/_src/algorithms/designers/gp_bandit.py | 164 +++++++++++++++--- .../algorithms/designers/gp_bandit_test.py | 163 ++++++++++++++--- 4 files changed, 293 insertions(+), 59 deletions(-) diff --git a/vizier/_src/algorithms/designers/gp/gp_models.py b/vizier/_src/algorithms/designers/gp/gp_models.py index ff25a22ff..eadcf8bd3 100644 --- a/vizier/_src/algorithms/designers/gp/gp_models.py +++ b/vizier/_src/algorithms/designers/gp/gp_models.py @@ -277,13 +277,14 @@ def train_gp( """Trains a Gaussian Process model. If `spec` contains multiple elements, each will be used to train a - `StackedResidualGP`, sequentially. The last entry will be used to train the + `StackedResidualGP`, sequentially. The first entry will be used to train the first GP, and then subsequent GPs will be trained on the residuals from the - previous GP. This process completes in reverse order, such that `spec[-1]` is - the first GP trained and `spec[0]` is the last GP trained. + previous GP. This process completes in the order that `spec` and `data are + provided, such that `spec[0]` is the first GP trained and `spec[-1]` is the + last GP trained. - spec[0] and data[0] make up the top-level GP, and spec[1:] and data[1:] define - the priors in context of transfer learning. + spec[-1] and data[-1] make up the top-level GP, and spec[:-1] and data[:-1] + define the priors in context of transfer learning. Args: spec: Specification for how to train a GP model. If multiple specs are @@ -314,7 +315,7 @@ def train_gp( ) curr_gp: Optional[GPState] = None - for curr_spec, curr_data in reversed(list(zip(spec, data))): + for curr_spec, curr_data in zip(spec, data): if curr_gp is None: # We are on the first iteration. curr_gp = _train_gp(spec=curr_spec, data=curr_data) diff --git a/vizier/_src/algorithms/designers/gp/gp_models_test.py b/vizier/_src/algorithms/designers/gp/gp_models_test.py index 54c346eed..6624ff4c1 100644 --- a/vizier/_src/algorithms/designers/gp/gp_models_test.py +++ b/vizier/_src/algorithms/designers/gp/gp_models_test.py @@ -191,7 +191,7 @@ def test_sequential_base_accuracy( # Combine the good base and the bad top into transfer learning GP. seq_base_gp = gp_models.train_gp( - [top_spec, base_spec], [top_train_data, base_train_data] + [base_spec, top_spec], [base_train_data, top_train_data] ) # Create a purposefully-bad GP with `bad_num_samples` for comparison. @@ -244,8 +244,8 @@ def test_multi_base( ensemble_size=ensemble_size, ) - train_specs = [top_spec] - train_data = [top_train_data] + train_specs = [] + train_data = [] for _ in range(2): base_spec, base_train_data, _ = _setup_lambda_search( @@ -257,6 +257,8 @@ def test_multi_base( ) train_specs.append(base_spec) train_data.append(base_train_data) + train_specs.append(top_spec) + train_data.append(top_train_data) seq_base_gp = gp_models.train_gp(train_specs, train_data) @@ -323,10 +325,10 @@ def test_bad_base_resilience( # Combine the good base and the bad top into transfer learning GP. seq_base_gp = gp_models.train_gp( [ - top_spec, bad_base_spec, + top_spec, ], - [top_train_data, bad_base_train_data], + [bad_base_train_data, top_train_data], ) # Create a GP on the fake objective with sufficient training data diff --git a/vizier/_src/algorithms/designers/gp_bandit.py b/vizier/_src/algorithms/designers/gp_bandit.py index 8895755d0..9e45e6efb 100644 --- a/vizier/_src/algorithms/designers/gp_bandit.py +++ b/vizier/_src/algorithms/designers/gp_bandit.py @@ -129,6 +129,17 @@ class VizierGPBandit(vza.Designer, vza.Predictor): _last_computed_gp: gp_models.GPState = attr.field(init=False) + # The studies used in transfer learning. Ordered in training order, i.e. + # a GP is trained on `_prior_studies[0]` first, then one is trained on the + # residuals of `_prior_studies[1]` from the GP trained on `_prior_studies[0]`, + # and so on. + _prior_studies: list[vza.CompletedTrials] = attr.field( + factory=list, init=False + ) + _incorporated_prior_study_count: int = attr.field( + default=0, kw_only=True, init=False + ) + default_acquisition_optimizer_factory = vb.VectorizedOptimizerFactory( strategy_factory=es.VectorizedEagleStrategyFactory() ) @@ -205,6 +216,32 @@ def update( del all_active self._trials.extend(copy.deepcopy(completed.trials)) + def update_priors(self, prior_studies: Sequence[vza.CompletedTrials]) -> None: + """Updates the list of prior studies for transfer learning. + + Each element is treated as a new prior study, and will be stacked in order + received - i.e. the first entry is for the first GP, the second entry is for + the GP trained on the residuals of the first GP, etc. + + See section 3.3 of https://dl.acm.org/doi/10.1145/3097983.3098043 for more + information, or see `gp/gp_models.py` and `gp/transfer_learning.py` + + Transfer learning is resilient to bad priors. + + Multiple calls are permitted. It is up to the caller to ensure + `prior_studies` have a matching `ProblemStatement`, otherwise behavior is + undefined. + + TODO: Decide on whether this method should become part of an + interface. + + Args: + prior_studies: A list of lists of completed trials, with one list per + prior study. The designer will train a prior GP for each list of prior + trials (for each `CompletedStudy` entry), in the order received. + """ + self._prior_studies.extend(copy.deepcopy(prior_studies)) + @property def _metric_info(self) -> vz.MetricInformation: return self._problem.metric_information.item() @@ -286,46 +323,103 @@ def _trials_to_data(self, trials: Sequence[vz.Trial]) -> types.ModelData: return types.ModelData(model_data.features, labels) @_experimental_override_allowed - def _train_gp( + def _create_gp_spec( self, data: types.ModelData, ard_rng: jax.random.KeyArray - ) -> gp_models.GPState: - """Overrideable training of a pre-computed ensemble GP.""" - trained_gp = gp_models.train_gp( - spec=gp_models.GPTrainingSpec( - ard_optimizer=self._ard_optimizer, - ard_rng=ard_rng, - coroutine=gp_models.get_vizier_gp_coroutine( - features=data.features, linear_coef=self._linear_coef - ), - ensemble_size=self._ensemble_size, - ard_random_restarts=self._ard_random_restarts, + ) -> gp_models.GPTrainingSpec: + """Overrideable creation of a training spec for a GP model.""" + return gp_models.GPTrainingSpec( + ard_optimizer=self._ard_optimizer, + ard_rng=ard_rng, + coroutine=gp_models.get_vizier_gp_coroutine( + features=data.features, linear_coef=self._linear_coef ), - data=data, + ensemble_size=self._ensemble_size, + ard_random_restarts=self._ard_random_restarts, ) - return trained_gp + + @_experimental_override_allowed + def _train_gp_with_priors( + self, + data: types.ModelData, + ard_rng: jax.random.KeyArray, + priors: Sequence[types.ModelData], + ): + """Trains a transfer-learning-enabled GP with prior studies. + + Args: + data: top-level data on which to train a GP. + ard_rng: RNG to do ARD to optimize GP parameters. + priors: Data for each sequential prior to train for transfer learning. + Assumed to be in order of training, i.e. element 0 is priors[0] is the + first GP trained, and priors[1] trains a GP on the residuals of the GP + trained on priors[0], and so on. + + Returns: + A trained pre-computed ensemble GP. + """ + ard_rngs = jax.random.split(ard_rng, len(priors) + 1) + + # Order `specs` in training order, i.e. `specs[0]` is trained first. + specs = [ + self._create_gp_spec(prior_data, ard_rngs[i]) + for i, prior_data in enumerate(priors) + ] + + # Use the last rng for the top level spec. + specs.append(self._create_gp_spec(data, ard_rngs[-1])) + + # Order `training_data` in training order, i.e. `training_data[0]` is + # trained first. + training_data = list(priors) + training_data.append(data) + + # `train_gp` expects `specs` and `data` in training order, which is how + # they were prepared above. + return gp_models.train_gp(spec=specs, data=training_data) @profiler.record_runtime - def _update_gp(self, data: types.ModelData) -> gp_models.GPState: + def _update_gp( + self, + data: types.ModelData, + *, + prior_data: Optional[Sequence[types.ModelData]] = None, + ) -> gp_models.GPState: """Compute the designer's GP and caches the result. No-op without new data. Args: data: Data to go into GP. + prior_data: Data to train priors on, in training order. Returns: - GPBanditState object containing the designer's state. + `GPState` object containing the trained GP. 1. Convert trials to features and labels. 2. Trains a pre-computed ensemble GP. If no new trials were added since last call, no update will occur. """ - if len(self._trials) == self._incorporated_trials_count: - # If there's no change in the number of completed trials, don't update - # state. The assumption is that trials can't be removed. + if ( + len(self._trials) == self._incorporated_trials_count + and len(self._prior_studies) == self._incorporated_prior_study_count + ): + # If there's no change in the number of completed trials or the number of + # priors, don't update state. The assumption is that trials can't be + # removed. return self._last_computed_gp self._incorporated_trials_count = len(self._trials) + self._incorporated_prior_study_count = len(self._prior_studies) self._rng, ard_rng = jax.random.split(self._rng, 2) - self._last_computed_gp = self._train_gp(data=data, ard_rng=ard_rng) + + if not prior_data: + self._last_computed_gp = gp_models.train_gp( + spec=self._create_gp_spec(data, ard_rng), + data=data, + ) + else: + self._last_computed_gp = self._train_gp_with_priors( + data=data, ard_rng=ard_rng, priors=prior_data + ) + return self._last_computed_gp @_experimental_override_allowed @@ -379,6 +473,29 @@ def _optimize_acquisition( best_candidates, self._converter ) # [N, D] + @profiler.record_runtime + @_experimental_override_allowed + def _generate_data( + self, + ) -> tuple[types.ModelData, Optional[list[types.ModelData]]]: + """Converts trials to top-level and prior training data.""" + prior_data: Optional[list[types.ModelData]] = None + if self._prior_studies: + prior_data = [ + self._trials_to_data(prior_study.trials) + for prior_study in self._prior_studies + ] + + # The top level data must be converted last - because `_output_warper` + # depends on the support points that were supplied to it in `warp` to + # `unwarp` labels. It stores these support points each time `warp` is + # called, so the last `warp` call dictates the support points used in + # `unwarp`. Therefore, since we want to `unwarp` the predictions based off + # the current (top) study rather than any prior study, we need to call + # `warp` on the current study last. + data = self._trials_to_data(self._trials) + return data, prior_data + @profiler.record_runtime def suggest(self, count: int = 1) -> Sequence[vz.TrialSuggestion]: logging.info('Suggest called with count=%d', count) @@ -393,8 +510,8 @@ def suggest(self, count: int = 1) -> Sequence[vz.TrialSuggestion]: suggest_start_time = datetime.datetime.now() logging.info('Updating the designer state based on trials...') - data = self._trials_to_data(self._trials) - gp = self._update_gp(data) + data, prior_data = self._generate_data() + gp = self._update_gp(data, prior_data=prior_data) # Define acquisition function. scoring_fn = self._scoring_function_factory( @@ -437,7 +554,8 @@ def sample( if not trials: return np.zeros((num_samples, 0)) - gp = self._update_gp(self._trials_to_data(self._trials)) + data, prior_data = self._generate_data() + gp = self._update_gp(data, prior_data=prior_data) xs = self._converter.to_features(trials) xs = types.ModelInput( continuous=xs.continuous.replace_fill_value(0.0), diff --git a/vizier/_src/algorithms/designers/gp_bandit_test.py b/vizier/_src/algorithms/designers/gp_bandit_test.py index c0166e012..8720a7ce0 100644 --- a/vizier/_src/algorithms/designers/gp_bandit_test.py +++ b/vizier/_src/algorithms/designers/gp_bandit_test.py @@ -16,6 +16,7 @@ """Tests for gp_bandit.""" +from typing import Callable from unittest import mock import jax @@ -47,6 +48,69 @@ def _build_mock_continuous_array_specs(n): return [continuous_spec] * n +def _setup_lambda_search( + f: Callable[[float], float], num_trials: int = 100 +) -> tuple[gp_bandit.VizierGPBandit, list[vz.Trial]]: + """Sets up a GP designer and outputs completed studies for `f`. + + Args: + f: 1D objective to be optimized, i.e. f(x), where x is a scalar in [-5., 5.) + num_trials: Number of mock "evaluated" trials to return. + + Returns: + A GP designer set up for the problem of optimizing the objective, without any + data updated. + Evaluated trials against `f`. + """ + assert ( + num_trials > 0 + ), f'Must provide a positive number of trials. Got {num_trials}.' + + search_space = vz.SearchSpace() + search_space.root.add_float_param('x0', -5.0, 5.0) + problem = vz.ProblemStatement( + search_space=search_space, + metric_information=vz.MetricsConfig( + metrics=[ + vz.MetricInformation('obj', goal=vz.ObjectiveMetricGoal.MAXIMIZE), + ] + ), + ) + + suggestions = quasi_random.QuasiRandomDesigner( + problem.search_space, seed=1 + ).suggest(num_trials) + + obs_trials = [] + for idx, suggestion in enumerate(suggestions): + trial = suggestion.to_trial(idx) + x = suggestions[idx].parameters['x0'].value + trial.complete(vz.Measurement(metrics={'obj': f(x)})) + obs_trials.append(trial) + + gp_designer = gp_bandit.VizierGPBandit(problem, ard_optimizer=ard_optimizer) + return gp_designer, obs_trials + + +def _compute_mse( + designer: gp_bandit.VizierGPBandit, + test_trials: list[vz.Trial], + y_test: list[float], +) -> float: + """Evaluate the designer's accuracy on the test set. + + Args: + designer: The GP bandit designer to predict from. + test_trials: The trials of the test set + y_test: The results of the test set + + Returns: + The MSE of `designer` on `test_trials` and `y_test` + """ + preds = designer.predict(test_trials) + return np.sum(np.square(preds.mean - y_test)) + + class GoogleGpBanditTest(parameterized.TestCase): @parameterized.parameters( @@ -216,32 +280,8 @@ def test_on_flat_mixed_space( self.assertFalse(np.isnan(prediction.stddev).any()) def test_prediction_accuracy(self): - search_space = vz.SearchSpace() - search_space.root.add_float_param('x0', -5.0, 5.0) - problem = vz.ProblemStatement( - search_space=search_space, - metric_information=vz.MetricsConfig( - metrics=[ - vz.MetricInformation( - 'obj', goal=vz.ObjectiveMetricGoal.MAXIMIZE - ), - ] - ), - ) f = lambda x: -((x - 0.5) ** 2) - - suggestions = quasi_random.QuasiRandomDesigner( - problem.search_space, seed=1 - ).suggest(100) - - obs_trials = [] - for idx, suggestion in enumerate(suggestions): - trial = suggestion.to_trial(idx) - x = suggestions[idx].parameters['x0'].value - trial.complete(vz.Measurement(metrics={'obj': f(x)})) - obs_trials.append(trial) - - gp_designer = gp_bandit.VizierGPBandit(problem, ard_optimizer=ard_optimizer) + gp_designer, obs_trials = _setup_lambda_search(f) gp_designer.update(vza.CompletedTrials(obs_trials), vza.ActiveTrials()) pred_trial = vz.Trial({'x0': 0.0}) pred = gp_designer.predict([pred_trial]) @@ -261,6 +301,7 @@ def test_jit_once(self, *args): name='metric', goal=vz.ObjectiveMetricGoal.MAXIMIZE ) ) + def create_designer(problem): return gp_bandit.VizierGPBandit( problem=problem, @@ -298,6 +339,78 @@ def create_runner(problem): designer2 = create_designer(problem) create_runner(problem).run_designer(designer2) + def test_priors_work(self): + f = lambda x: -((x - 0.5) ** 2) + + # X is in range of what is defined in `_setup_lambda_search`, [-5.0, 5.0) + x_test = np.random.default_rng(1).uniform(-5.0, 5.0, 100) + y_test = [f(x) for x in x_test] + test_trials = [vz.Trial({'x0': x}) for x in x_test] + + # Create the designer with a prior and the trials to train the prior. + gp_designer_with_prior, obs_trials_for_prior = _setup_lambda_search( + f=f, num_trials=100 + ) + + # Update prior with above trials. + gp_designer_with_prior.update_priors( + [vza.CompletedTrials(obs_trials_for_prior)] + ) + + # Purposefully set a low number of trials for the actual study, so that + # the designer without the prior will predict with poor accuracy. + gp_designer_no_prior, obs_trials = _setup_lambda_search(f=f, num_trials=20) + + # Update both priors with the actual study. + gp_designer_no_prior.update( + vza.CompletedTrials(obs_trials), vza.ActiveTrials() + ) + gp_designer_with_prior.update( + vza.CompletedTrials(obs_trials), vza.ActiveTrials() + ) + + # Evaluate the no prior designer's accuracy on the test set. + mse_no_prior = _compute_mse(gp_designer_no_prior, test_trials, y_test) + + # Evaluate the designer with prior's accuracy on the test set. + mse_with_prior = _compute_mse(gp_designer_with_prior, test_trials, y_test) + + # The designer with a prior should predict better. + self.assertLess(mse_with_prior, mse_no_prior) + + def test_multiple_priors(self): + """Tests that a multi-prior GP predicts better than a GP with one prior.""" + f = lambda x: -((x - 0.5) ** 2) + multi_prior_gp_designer, multi_prior_trials = _setup_lambda_search( + f, num_trials=300 + ) + prior_0, prior_1, top = np.array_split(multi_prior_trials, 3) + multi_prior_gp_designer.update_priors([vza.CompletedTrials(prior_0)]) + multi_prior_gp_designer.update_priors([vza.CompletedTrials(prior_1)]) + multi_prior_gp_designer.update(vza.CompletedTrials(top), vza.ActiveTrials()) + self.assertLen(multi_prior_gp_designer._prior_studies, 2) + self.assertLen(multi_prior_gp_designer._trials, len(top)) + + single_prior_gp_designer, single_prior_trials = _setup_lambda_search( + f, num_trials=200 + ) + prior, top = np.array_split(single_prior_trials, 2) + single_prior_gp_designer.update_priors([vza.CompletedTrials(prior)]) + single_prior_gp_designer.update( + vza.CompletedTrials(top), vza.ActiveTrials() + ) + self.assertLen(single_prior_gp_designer._prior_studies, 1) + self.assertLen(single_prior_gp_designer._trials, len(top)) + + x_test = np.random.default_rng(1).uniform(-5.0, 5.0, 100) + y_test = [f(x) for x in x_test] + test_trials = [vz.Trial({'x0': x}) for x in x_test] + multi_prior_mse = _compute_mse(multi_prior_gp_designer, test_trials, y_test) + single_prior_mse = _compute_mse( + single_prior_gp_designer, test_trials, y_test + ) + self.assertLess(multi_prior_mse, single_prior_mse) + if __name__ == '__main__': # Jax disables float64 computations by default and will silently convert