Internal Change #833

Open
wants to merge 1 commit into base: main
13 changes: 7 additions & 6 deletions vizier/_src/algorithms/designers/gp/gp_models.py
@@ -277,13 +277,14 @@ def train_gp(
"""Trains a Gaussian Process model.

If `spec` contains multiple elements, each will be used to train a
- `StackedResidualGP`, sequentially. The last entry will be used to train the
+ `StackedResidualGP`, sequentially. The first entry will be used to train the
first GP, and then subsequent GPs will be trained on the residuals from the
- previous GP. This process completes in reverse order, such that `spec[-1]` is
- the first GP trained and `spec[0]` is the last GP trained.
+ previous GP. This process completes in the order that `spec` and `data` are
+ provided, such that `spec[0]` is the first GP trained and `spec[-1]` is the
+ last GP trained.

- spec[0] and data[0] make up the top-level GP, and spec[1:] and data[1:] define
- the priors in context of transfer learning.
+ spec[-1] and data[-1] make up the top-level GP, and spec[:-1] and data[:-1]
+ define the priors in context of transfer learning.

Args:
spec: Specification for how to train a GP model. If multiple specs are
@@ -314,7 +315,7 @@ def train_gp(
)

curr_gp: Optional[GPState] = None
- for curr_spec, curr_data in reversed(list(zip(spec, data))):
+ for curr_spec, curr_data in zip(spec, data):
if curr_gp is None:
# We are on the first iteration.
curr_gp = _train_gp(spec=curr_spec, data=curr_data)
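For orientation, here is a minimal sketch of the ordering contract after this change: specs and data are passed in training order, priors first and the top-level study last, with `spec[i]` paired with `data[i]`. The variable names (`base_data`, `top_data`, `ard_optimizer`, the rng values) and the chosen spec parameters are placeholders, not part of this diff.

import jax

rng_base, rng_top = jax.random.split(jax.random.PRNGKey(0))

base_spec = gp_models.GPTrainingSpec(
    ard_optimizer=ard_optimizer,  # placeholder ARD optimizer
    ard_rng=rng_base,
    coroutine=gp_models.get_vizier_gp_coroutine(features=base_data.features),
    ensemble_size=1,
    ard_random_restarts=4,
)
top_spec = gp_models.GPTrainingSpec(
    ard_optimizer=ard_optimizer,
    ard_rng=rng_top,
    coroutine=gp_models.get_vizier_gp_coroutine(features=top_data.features),
    ensemble_size=1,
    ard_random_restarts=4,
)

# The prior GP is trained on `base_data` first; the top-level GP is then
# trained on the residuals of `top_data` under the prior's predictions.
gp = gp_models.train_gp(spec=[base_spec, top_spec], data=[base_data, top_data])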
12 changes: 7 additions & 5 deletions vizier/_src/algorithms/designers/gp/gp_models_test.py
@@ -191,7 +191,7 @@ def test_sequential_base_accuracy(

# Combine the good base and the bad top into transfer learning GP.
seq_base_gp = gp_models.train_gp(
- [top_spec, base_spec], [top_train_data, base_train_data]
+ [base_spec, top_spec], [base_train_data, top_train_data]
)

# Create a purposefully-bad GP with `bad_num_samples` for comparison.
@@ -244,8 +244,8 @@ def test_multi_base(
ensemble_size=ensemble_size,
)

- train_specs = [top_spec]
- train_data = [top_train_data]
+ train_specs = []
+ train_data = []

for _ in range(2):
base_spec, base_train_data, _ = _setup_lambda_search(
@@ -257,6 +257,8 @@
)
train_specs.append(base_spec)
train_data.append(base_train_data)
train_specs.append(top_spec)
train_data.append(top_train_data)

seq_base_gp = gp_models.train_gp(train_specs, train_data)

@@ -323,10 +325,10 @@ def test_bad_base_resilience(
# Combine the good base and the bad top into transfer learning GP.
seq_base_gp = gp_models.train_gp(
[
- top_spec,
bad_base_spec,
+ top_spec,
],
- [top_train_data, bad_base_train_data],
+ [bad_base_train_data, top_train_data],
)

# Create a GP on the fake objective with sufficient training data
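Conceptually, the stacking these tests exercise fits each later model on what the earlier ones failed to explain, and sums the predictions. A rough standalone illustration of that residual idea (plain NumPy curve fitting, not the library's `StackedResidualGP`):

import numpy as np

x = np.linspace(0.0, 1.0, 20)
y_prior = np.sin(3 * x)            # objective of the prior ("base") study
y_top = np.sin(3 * x) + 0.3 * x    # related objective of the current ("top") study

def fit(x, y):
  # Stand-in for "train a GP": a simple cubic polynomial fit.
  coeffs = np.polyfit(x, y, deg=3)
  return lambda x_new: np.polyval(coeffs, x_new)

base_model = fit(x, y_prior)        # trained first, on the prior study
residuals = y_top - base_model(x)   # what the prior fails to explain
top_model = fit(x, residuals)       # trained last, on the residuals

def predict(x_new):
  return base_model(x_new) + top_model(x_new)   # stacked prediction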
164 changes: 141 additions & 23 deletions vizier/_src/algorithms/designers/gp_bandit.py
@@ -129,6 +129,17 @@ class VizierGPBandit(vza.Designer, vza.Predictor):

_last_computed_gp: gp_models.GPState = attr.field(init=False)

# The studies used in transfer learning. Ordered in training order, i.e.
# a GP is trained on `_prior_studies[0]` first, then one is trained on the
# residuals of `_prior_studies[1]` from the GP trained on `_prior_studies[0]`,
# and so on.
_prior_studies: list[vza.CompletedTrials] = attr.field(
factory=list, init=False
)
_incorporated_prior_study_count: int = attr.field(
default=0, kw_only=True, init=False
)

default_acquisition_optimizer_factory = vb.VectorizedOptimizerFactory(
strategy_factory=es.VectorizedEagleStrategyFactory()
)
@@ -205,6 +216,32 @@ def update(
del all_active
self._trials.extend(copy.deepcopy(completed.trials))

def update_priors(self, prior_studies: Sequence[vza.CompletedTrials]) -> None:
"""Updates the list of prior studies for transfer learning.

Each element is treated as a new prior study, and will be stacked in order
received - i.e. the first entry is for the first GP, the second entry is for
the GP trained on the residuals of the first GP, etc.

See section 3.3 of https://dl.acm.org/doi/10.1145/3097983.3098043 for more
information, or see `gp/gp_models.py` and `gp/transfer_learning.py`

Transfer learning is resilient to bad priors.

Multiple calls are permitted. It is up to the caller to ensure
`prior_studies` have a matching `ProblemStatement`, otherwise behavior is
undefined.

TODO: Decide on whether this method should become part of an
interface.

Args:
prior_studies: A list of lists of completed trials, with one list per
prior study. The designer will train a prior GP for each list of prior
trials (for each `CompletedTrials` entry), in the order received.
"""
self._prior_studies.extend(copy.deepcopy(prior_studies))

@property
def _metric_info(self) -> vz.MetricInformation:
return self._problem.metric_information.item()
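A short sketch of how a caller might register priors with the new method (`designer` is a `VizierGPBandit`; the trial lists are placeholders). Repeated calls extend the list, and the list order is the training order:

# Study A's GP is trained first, study B's GP is trained on its residuals,
# and the current study's GP is trained last on the remaining residuals.
designer.update_priors([vza.CompletedTrials(trials=study_a_trials)])
designer.update_priors([vza.CompletedTrials(trials=study_b_trials)])  # appends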
@@ -286,46 +323,103 @@ def _trials_to_data(self, trials: Sequence[vz.Trial]) -> types.ModelData:
return types.ModelData(model_data.features, labels)

@_experimental_override_allowed
- def _train_gp(
+ def _create_gp_spec(
self, data: types.ModelData, ard_rng: jax.random.KeyArray
- ) -> gp_models.GPState:
- """Overrideable training of a pre-computed ensemble GP."""
- trained_gp = gp_models.train_gp(
- spec=gp_models.GPTrainingSpec(
- ard_optimizer=self._ard_optimizer,
- ard_rng=ard_rng,
- coroutine=gp_models.get_vizier_gp_coroutine(
- features=data.features, linear_coef=self._linear_coef
- ),
- ensemble_size=self._ensemble_size,
- ard_random_restarts=self._ard_random_restarts,
+ ) -> gp_models.GPTrainingSpec:
+ """Overrideable creation of a training spec for a GP model."""
+ return gp_models.GPTrainingSpec(
+ ard_optimizer=self._ard_optimizer,
+ ard_rng=ard_rng,
+ coroutine=gp_models.get_vizier_gp_coroutine(
+ features=data.features, linear_coef=self._linear_coef
+ ),
- data=data,
+ ensemble_size=self._ensemble_size,
+ ard_random_restarts=self._ard_random_restarts,
+ )
- return trained_gp

@_experimental_override_allowed
def _train_gp_with_priors(
self,
data: types.ModelData,
ard_rng: jax.random.KeyArray,
priors: Sequence[types.ModelData],
):
"""Trains a transfer-learning-enabled GP with prior studies.

Args:
data: top-level data on which to train a GP.
ard_rng: RNG to do ARD to optimize GP parameters.
priors: Data for each sequential prior to train for transfer learning.
Assumed to be in training order, i.e. `priors[0]` trains the first GP,
`priors[1]` trains a GP on the residuals of the GP trained on
`priors[0]`, and so on.

Returns:
A trained pre-computed ensemble GP.
"""
ard_rngs = jax.random.split(ard_rng, len(priors) + 1)

# Order `specs` in training order, i.e. `specs[0]` is trained first.
specs = [
self._create_gp_spec(prior_data, ard_rngs[i])
for i, prior_data in enumerate(priors)
]

# Use the last rng for the top level spec.
specs.append(self._create_gp_spec(data, ard_rngs[-1]))

# Order `training_data` in training order, i.e. `training_data[0]` is
# trained first.
training_data = list(priors)
training_data.append(data)

# `train_gp` expects `specs` and `data` in training order, which is how
# they were prepared above.
return gp_models.train_gp(spec=specs, data=training_data)

@profiler.record_runtime
- def _update_gp(self, data: types.ModelData) -> gp_models.GPState:
+ def _update_gp(
+ self,
+ data: types.ModelData,
+ *,
+ prior_data: Optional[Sequence[types.ModelData]] = None,
+ ) -> gp_models.GPState:
"""Compute the designer's GP and caches the result. No-op without new data.

Args:
data: Data to go into GP.
prior_data: Data to train priors on, in training order.

Returns:
- GPBanditState object containing the designer's state.
+ `GPState` object containing the trained GP.

1. Convert trials to features and labels.
2. Trains a pre-computed ensemble GP.

If no new trials were added since last call, no update will occur.
"""
- if len(self._trials) == self._incorporated_trials_count:
- # If there's no change in the number of completed trials, don't update
- # state. The assumption is that trials can't be removed.
+ if (
+ len(self._trials) == self._incorporated_trials_count
+ and len(self._prior_studies) == self._incorporated_prior_study_count
+ ):
+ # If there's no change in the number of completed trials or the number of
+ # priors, don't update state. The assumption is that trials can't be
+ # removed.
return self._last_computed_gp
self._incorporated_trials_count = len(self._trials)
self._incorporated_prior_study_count = len(self._prior_studies)
self._rng, ard_rng = jax.random.split(self._rng, 2)
- self._last_computed_gp = self._train_gp(data=data, ard_rng=ard_rng)

+ if not prior_data:
+ self._last_computed_gp = gp_models.train_gp(
+ spec=self._create_gp_spec(data, ard_rng),
+ data=data,
+ )
+ else:
+ self._last_computed_gp = self._train_gp_with_priors(
+ data=data, ard_rng=ard_rng, priors=prior_data
+ )

return self._last_computed_gp

@_experimental_override_allowed
@@ -379,6 +473,29 @@ def _optimize_acquisition(
best_candidates, self._converter
) # [N, D]

@profiler.record_runtime
@_experimental_override_allowed
def _generate_data(
self,
) -> tuple[types.ModelData, Optional[list[types.ModelData]]]:
"""Converts trials to top-level and prior training data."""
prior_data: Optional[list[types.ModelData]] = None
if self._prior_studies:
prior_data = [
self._trials_to_data(prior_study.trials)
for prior_study in self._prior_studies
]

# The top level data must be converted last - because `_output_warper`
# depends on the support points that were supplied to it in `warp` to
# `unwarp` labels. It stores these support points each time `warp` is
# called, so the last `warp` call dictates the support points used in
# `unwarp`. Therefore, since we want to `unwarp` the predictions based off
# the current (top) study rather than any prior study, we need to call
# `warp` on the current study last.
data = self._trials_to_data(self._trials)
return data, prior_data

@profiler.record_runtime
def suggest(self, count: int = 1) -> Sequence[vz.TrialSuggestion]:
logging.info('Suggest called with count=%d', count)
@@ -393,8 +510,8 @@ def suggest(self, count: int = 1) -> Sequence[vz.TrialSuggestion]:

suggest_start_time = datetime.datetime.now()
logging.info('Updating the designer state based on trials...')
- data = self._trials_to_data(self._trials)
- gp = self._update_gp(data)
+ data, prior_data = self._generate_data()
+ gp = self._update_gp(data, prior_data=prior_data)

# Define acquisition function.
scoring_fn = self._scoring_function_factory(
@@ -437,7 +554,8 @@ def sample(
if not trials:
return np.zeros((num_samples, 0))

- gp = self._update_gp(self._trials_to_data(self._trials))
+ data, prior_data = self._generate_data()
+ gp = self._update_gp(data, prior_data=prior_data)
xs = self._converter.to_features(trials)
xs = types.ModelInput(
continuous=xs.continuous.replace_fill_value(0.0),
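Taken together, a rough end-to-end sketch of how the new transfer-learning path would be driven from calling code (the `problem` and trial objects are placeholders, and the exact constructor arguments are assumptions, not part of this diff):

designer = VizierGPBandit(problem)

# Register a prior study once; priors are kept across suggest/sample calls.
designer.update_priors([vza.CompletedTrials(trials=prior_study_trials)])

# Feed completed trials of the current study as usual.
designer.update(vza.CompletedTrials(trials=current_trials), vza.ActiveTrials())

# The next suggest (or sample) call converts prior and current trials to data,
# trains the stacked GP, and proceeds with acquisition as before.
suggestions = designer.suggest(count=2)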