From e4a15daa7a9e3b25891d520d8711cde5fa8acd0f Mon Sep 17 00:00:00 2001 From: Ash Blum Date: Fri, 18 Jun 2021 14:12:56 -0400 Subject: [PATCH 01/13] Initial commit for Tree-structured Parzen Estimator feature --- bayes_search.py | 125 +++++++++++++++++++++++++++++++++++-- tests/test_bayes_search.py | 10 ++- 2 files changed, 128 insertions(+), 7 deletions(-) diff --git a/bayes_search.py b/bayes_search.py index f4ef58bc..465d5960 100644 --- a/bayes_search.py +++ b/bayes_search.py @@ -154,10 +154,11 @@ def next_sample( X_bounds: Optional[npt.ArrayLike] = None, current_X: Optional[npt.ArrayLike] = None, nu: floating = 1.5, - max_samples_for_gp: integer = 100, + max_samples_for_model: integer = 100, improvement: floating = 0.01, num_points_to_try: integer = 1000, opt_func: str = "expected_improvement", + model: str = "gp", test_X: Optional[npt.ArrayLike] = None, ) -> Tuple[npt.ArrayLike, floating, floating, Optional[floating], Optional[floating]]: """Calculates the best next sample to look at via bayesian optimization. @@ -177,7 +178,7 @@ def next_sample( http://scikit-learn.org/stable/modules/generated/sklearn.gaussian_process.kernels.Matern.html - max_samples_for_gp: integer, optional, default 100 + max_samples_for_model: integer, optional, default 100 maximum samples to consider (since algo is O(n^3)) for performance, but also adds some randomness. this number of samples will be chosen randomly from the sample_X and used to train the GP. @@ -190,6 +191,8 @@ def next_sample( improvement of probability of improvement. Expected improvement is generally better - may want to remove probability of improvement at some point. (But I think prboability of improvement is a little easier to calculate) + model: one of {"gp", "tpe"} - whether to use a Gaussian Process as a surrogate model or + a Tree-structured Parzen Estimator test_X: X values to test when looking for the best values to try Returns: @@ -246,9 +249,46 @@ def next_sample( None, ) + if model=="gp": + return next_sample_gp( + filtered_X=filtered_X, + filtered_y=filtered_y, + X_bounds=X_bounds, + current_X=current_X, + nu=nu, + max_samples_for_model=max_samples_for_model, + improvement=improvement, + num_points_to_try=num_points_to_try, + opt_func=opt_func, + test_X=test_X + ) + elif model=="tpe": + return next_sample_tpe( + filtered_X=filtered_X, + filtered_y=filtered_y, + X_bounds=X_bounds, + current_X=current_X, + max_samples_for_model=max_samples_for_model, + improvement=improvement, + num_points_to_try=num_points_to_try, + test_X=test_X + ) + +def next_sample_gp( + filtered_X: npt.ArrayLike, + filtered_y: npt.ArrayLike, + X_bounds: Optional[npt.ArrayLike] = None, + current_X: Optional[npt.ArrayLike] = None, + nu: floating = 1.5, + max_samples_for_model: integer = 100, + improvement: floating = 0.01, + num_points_to_try: integer = 1000, + opt_func: str = "expected_improvement", + test_X: Optional[npt.ArrayLike] = None, +) -> Tuple[npt.ArrayLike, floating, floating, Optional[floating], Optional[floating]]: # build the acquisition function gp, y_mean, y_stddev, = train_gaussian_process( - filtered_X, filtered_y, X_bounds, current_X, nu, max_samples_for_gp + filtered_X, filtered_y, X_bounds, current_X, nu, max_samples_for_model ) # Look for the minimum value of our fitted-target-function + (kappa * fitted-target-std_dev) if test_X is None: # this is the usual case @@ -306,6 +346,74 @@ def next_sample( suggested_X_expected_improvement, ) +def fit_1D_parzen_estimator(X, x_min, x_max): + mus = X.copy() + sorted_mus = np.sort(mus) + extended_mus = np.concatenate((np.array([x_min]), sorted_mus, np.array([x_max]))) + sigmas = np.maximum(extended_mus[2:]-extended_mus[1:-1], extended_mus[1:-1] - extended_mus[0:-2]) + sigmas = np.maximum(sigmas, 1e-6) + return (mus, sigmas) + +def sample_from_1D_parzen_estimator(mus, sigmas, x_min, x_max, indices): +# which_mu = np.argmax(np.multinomial(1, [1.0/len(mus)]*len(mus)) + new_samples = np.zeros(len(indices)) + + # For which_mu == -1, sample from the (uniform) prior + new_samples[indices == -1] = np.random.default_rng().uniform(x_min, x_max, np.sum(indices == -1)) + # Other samples are from mus +# which_mu = which_mu[which_mu >=0] + new_samples[indices >= 0] = np.random.default_rng().normal(loc = mus[indices[indices >= 0]], + scale = sigmas[indices[indices >= 0]]) + return np.clip(new_samples, x_min, x_max) + +def llik_from_1D_parzen_estimator(samples, mus, sigmas, x_min, x_max): + llik = np.array(len(samples)) + samp_norm = (np.tile(samples, [len(mus), 1]).T - mus) / sigmas + llik = np.log((np.sum(scipy_stats.norm.pdf(samp_norm), axis=1) + 1.0 / (x_max - x_min)) / (len(mus) + 1.0)) + return llik + +def next_sample_tpe( + filtered_X: npt.ArrayLike, + filtered_y: npt.ArrayLike, + X_bounds: Optional[npt.ArrayLike] = None, + current_X: Optional[npt.ArrayLike] = None, + max_samples_for_model: integer = 100, + improvement: floating = 0.01, + num_points_to_try: integer = 1000, + test_X: Optional[npt.ArrayLike] = None, +) -> Tuple[npt.ArrayLike, floating, floating, Optional[floating], Optional[floating]]: + + y_star = np.quantile(filtered_y, improvement) + if X_bounds is None: + hp_min = np.min(filtered_X, axis=0) + hp_max = np.max(filtered_X, axis=0) + X_bounds = np.column_stack(hp_min, hp_max) + + low_X = filtered_X[filtered_y <= y_star] + high_X = filtered_X[filtered_y > y_star] + num_hp = low_X.shape[1] + new_samples = np.zeros((num_points_to_try, num_hp)) + low_llik = np.zeros((num_points_to_try, num_hp)) + high_llik = np.zeros((num_points_to_try, num_hp)) + which_mu = np.random.default_rng().integers(-1, len(low_X), num_points_to_try) + # For each hyperparameter + for i in range(num_hp): + # Values below y_star + (x_min, x_max) = X_bounds[i] + (low_mus, low_sigmas) = fit_1D_parzen_estimator(low_X[:,i], x_min, x_max) + (high_mus, high_sigmas) = fit_1D_parzen_estimator(high_X[:,i], x_min, x_max) + new_samples[:, i] = sample_from_1D_parzen_estimator(low_mus, low_sigmas, x_min, x_max, which_mu) + low_llik[:,i] = llik_from_1D_parzen_estimator(new_samples[:,i], low_mus, low_sigmas, x_min, x_max) + high_llik[:,i] = llik_from_1D_parzen_estimator(new_samples[:,i], high_mus, high_sigmas, x_min, x_max) + + score = np.sum(low_llik - high_llik, axis=1) + return ( + new_samples[np.argmax(score),:], + None, + None, + None, + None, + ) def bayes_search_next_run( runs: List[SweepRun], @@ -339,7 +447,15 @@ def bayes_search_next_run( if "metric" not in config: raise ValueError('Bayesian search requires "metric" section') - if config["method"] != "bayes": + if isinstance(config["method"],str): + if config["method"] == "bayes": + config["method"] = { "name": "bayes", "model": "gp"} + else: + raise ValueError("Invalid sweep configuration for bayes_search_next_run.") + elif isinstance(config["method"],dict): + if config["method"]["model"] not in ("gp", "tpe"): + raise ValueError("Invalid sweep configuration for bayes_search_next_run.") + else: raise ValueError("Invalid sweep configuration for bayes_search_next_run.") goal = config["metric"]["goal"] @@ -418,6 +534,7 @@ def bayes_search_next_run( sample_y=y, X_bounds=X_bounds, current_X=current_X if len(current_X) > 0 else None, + model=config["method"]["model"], improvement=minimum_improvement, ) diff --git a/tests/test_bayes_search.py b/tests/test_bayes_search.py index 10d61795..499cb8b7 100644 --- a/tests/test_bayes_search.py +++ b/tests/test_bayes_search.py @@ -98,6 +98,7 @@ def run_iterations( optimium: Optional[npt.ArrayLike] = None, atol: Optional[npt.ArrayLike] = 0.2, chunk_size: integer = 1, + model: str = "gp" ) -> Tuple[npt.ArrayLike, npt.ArrayLike]: if x_init is not None: @@ -173,14 +174,15 @@ def test_squiggle_convergence(): run_iterations(squiggle, [[0.0, 5.0]], 200, x_init, optimium=[3.6], atol=0.2) -def test_squiggle_convergence_to_maximum(): +@pytest.mark.parametrize("model",["gp","tpe"]) +def test_squiggle_convergence_to_maximum(model): # This test checks whether the bayes algorithm correctly explores the parameter space # we sample a ton of positive examples, ignoring the negative side def f(x): return -squiggle(x) x_init = np.random.uniform(0, 5, 1)[:, None] - run_iterations(f, [[0.0, 5.0]], 200, x_init, optimium=[2], atol=0.2) + run_iterations(f, [[0.0, 5.0]], 200, x_init, optimium=[2], atol=0.2, model=model) def test_nans(): @@ -206,7 +208,8 @@ def test_squiggle_int(): assert np.isclose(sample % 1, 0) -def test_iterations_rosenbrock(): +@pytest.mark.parametrize("model",["gp","tpe"]) +def test_iterations_rosenbrock(model): dimensions = 3 # x_init = np.random.uniform(0, 2, size=(1, dimensions)) x_init = np.zeros((1, dimensions)) @@ -218,6 +221,7 @@ def test_iterations_rosenbrock(): optimium=[1, 1, 1], atol=0.2, improvement=0.1, + model=model ) From 34869f6cc0bd58db392dbc9da8fd5fa859874206 Mon Sep 17 00:00:00 2001 From: Ash Blum Date: Sat, 19 Jun 2021 10:19:33 -0400 Subject: [PATCH 02/13] Bugfix to make tests cover new tpe code --- tests/test_bayes_search.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/test_bayes_search.py b/tests/test_bayes_search.py index 499cb8b7..072ee79e 100644 --- a/tests/test_bayes_search.py +++ b/tests/test_bayes_search.py @@ -120,6 +120,7 @@ def run_iterations( X_bounds=bounds, current_X=sample_X, improvement=improvement, + model=model ) if sample_X is None: sample_X = np.array([sample]) From fe004bcc5489a89077eb056cba9820bd11eb1932 Mon Sep 17 00:00:00 2001 From: Ash Blum Date: Sun, 20 Jun 2021 23:11:10 -0400 Subject: [PATCH 03/13] Added multidimensional Parzen window model --- bayes_search.py | 90 ++++++++++++++++++++++++++++++++++++------------- 1 file changed, 67 insertions(+), 23 deletions(-) diff --git a/bayes_search.py b/bayes_search.py index 465d5960..f707880e 100644 --- a/bayes_search.py +++ b/bayes_search.py @@ -346,32 +346,58 @@ def next_sample_gp( suggested_X_expected_improvement, ) -def fit_1D_parzen_estimator(X, x_min, x_max): +def fit_1D_parzen_estimator(X, X_bounds): mus = X.copy() sorted_mus = np.sort(mus) - extended_mus = np.concatenate((np.array([x_min]), sorted_mus, np.array([x_max]))) + extended_mus = np.insert(X_bounds, 1, sorted_mus) sigmas = np.maximum(extended_mus[2:]-extended_mus[1:-1], extended_mus[1:-1] - extended_mus[0:-2]) sigmas = np.maximum(sigmas, 1e-6) return (mus, sigmas) -def sample_from_1D_parzen_estimator(mus, sigmas, x_min, x_max, indices): -# which_mu = np.argmax(np.multinomial(1, [1.0/len(mus)]*len(mus)) - new_samples = np.zeros(len(indices)) +def sample_from_parzen_estimator(mus, sigmas, X_bounds, num_samples): + which_mu = np.random.default_rng().integers(-1, len(mus), num_samples) + samples = np.zeros((num_samples, len(X_bounds))) + uniform_ind = (which_mu == -1) + num_uniform = np.count_nonzero(uniform_ind) + samples[uniform_ind] = np.random.default_rng().uniform( + np.tile(X_bounds[:,0],[num_uniform, 1]), + np.tile(X_bounds[:,1],[num_uniform, 1]) ) + normal_ind = (which_mu >= 0) + num_normal = np.count_nonzero(normal_ind) + samples[normal_ind] = np.random.default_rng().normal( + loc = mus[which_mu[normal_ind]], + scale = sigmas[which_mu[normal_ind]] ) + return np.clip(samples, X_bounds[:,0], X_bounds[:,1]) + +def sample_from_1D_parzen_estimator(mus, sigmas, x_min, x_max, num_points_to_try): + indices = np.random.default_rng().integers(-1, len(low_X), num_points_to_try) + new_samples = np.zeros(num_points_to_try) # For which_mu == -1, sample from the (uniform) prior new_samples[indices == -1] = np.random.default_rng().uniform(x_min, x_max, np.sum(indices == -1)) # Other samples are from mus -# which_mu = which_mu[which_mu >=0] new_samples[indices >= 0] = np.random.default_rng().normal(loc = mus[indices[indices >= 0]], scale = sigmas[indices[indices >= 0]]) return np.clip(new_samples, x_min, x_max) +def llik_from_parzen_estimator(samples, mus, sigmas, X_bounds): + samp_norm = (np.tile(samples, [len(mus), 1, 1]).transpose((1, 0, 2)) - mus) / sigmas + samp_norm = np.square(samp_norm) + normalization = (2.0 * np.pi) ** (-len(X_bounds) / 2.0) / np.prod(sigmas, axis=1) + pdf = normalization * np.exp(-0.5 * np.sum(samp_norm, axis = 2)) + uniform_pdf = 1.0 / np.prod(X_bounds[:,1] - X_bounds[:,0]) + mixture = (np.sum(pdf, axis = 1) + uniform_pdf) / (len(mus) + 1.0) + return np.log(mixture) + def llik_from_1D_parzen_estimator(samples, mus, sigmas, x_min, x_max): - llik = np.array(len(samples)) samp_norm = (np.tile(samples, [len(mus), 1]).T - mus) / sigmas llik = np.log((np.sum(scipy_stats.norm.pdf(samp_norm), axis=1) + 1.0 / (x_max - x_min)) / (len(mus) + 1.0)) return llik +def parzen_threshold(y, gamma): + low_ind = int(np.floor(gamma * np.sqrt(len(y)))) + return np.sort(y)[low_ind] + def next_sample_tpe( filtered_X: npt.ArrayLike, filtered_y: npt.ArrayLike, @@ -381,32 +407,50 @@ def next_sample_tpe( improvement: floating = 0.01, num_points_to_try: integer = 1000, test_X: Optional[npt.ArrayLike] = None, + fit_1D: Optional[bool] = False ) -> Tuple[npt.ArrayLike, floating, floating, Optional[floating], Optional[floating]]: - y_star = np.quantile(filtered_y, improvement) + y_star = np.quantile(filtered_y, improvement) if X_bounds is None: hp_min = np.min(filtered_X, axis=0) hp_max = np.max(filtered_X, axis=0) X_bounds = np.column_stack(hp_min, hp_max) + else: + X_bounds = np.array(X_bounds) low_X = filtered_X[filtered_y <= y_star] high_X = filtered_X[filtered_y > y_star] num_hp = low_X.shape[1] - new_samples = np.zeros((num_points_to_try, num_hp)) - low_llik = np.zeros((num_points_to_try, num_hp)) - high_llik = np.zeros((num_points_to_try, num_hp)) - which_mu = np.random.default_rng().integers(-1, len(low_X), num_points_to_try) - # For each hyperparameter - for i in range(num_hp): - # Values below y_star - (x_min, x_max) = X_bounds[i] - (low_mus, low_sigmas) = fit_1D_parzen_estimator(low_X[:,i], x_min, x_max) - (high_mus, high_sigmas) = fit_1D_parzen_estimator(high_X[:,i], x_min, x_max) - new_samples[:, i] = sample_from_1D_parzen_estimator(low_mus, low_sigmas, x_min, x_max, which_mu) - low_llik[:,i] = llik_from_1D_parzen_estimator(new_samples[:,i], low_mus, low_sigmas, x_min, x_max) - high_llik[:,i] = llik_from_1D_parzen_estimator(new_samples[:,i], high_mus, high_sigmas, x_min, x_max) - - score = np.sum(low_llik - high_llik, axis=1) + # Fitting separate parzen estimators to each hyperparameter + if fit_1D: + new_samples = np.zeros((num_points_to_try, num_hp)) + low_llik = np.zeros((num_points_to_try, num_hp)) + high_llik = np.zeros((num_points_to_try, num_hp)) + for i in range(num_hp): + # Values below y_star + (x_min, x_max) = X_bounds[i] + (low_mus, low_sigmas) = fit_1D_parzen_estimator(low_X[:,i], x_min, x_max) + (high_mus, high_sigmas) = fit_1D_parzen_estimator(high_X[:,i], x_min, x_max) + new_samples[:, i] = sample_from_1D_parzen_estimator(low_mus, low_sigmas, x_min, x_max, num_points_to_try) + low_llik[:,i] = llik_from_1D_parzen_estimator(new_samples[:,i], low_mus, low_sigmas, x_min, x_max) + high_llik[:,i] = llik_from_1D_parzen_estimator(new_samples[:,i], high_mus, high_sigmas, x_min, x_max) + score = np.sum(low_llik - high_llik, axis=1) + # Fitting a multidimensional Parzen estimator + else: + low_mus = low_X.copy() + low_sigmas = np.zeros((len(low_X), num_hp)) + high_mus = high_X.copy() + high_sigmas = np.zeros((len(high_X), num_hp)) + + for i in range(num_hp): + (low_mus[:,i], low_sigmas[:,i]) = fit_1D_parzen_estimator(low_X[:,i], X_bounds[i]) + (high_mus[:,i], high_sigmas[:,i]) = fit_1D_parzen_estimator(high_X[:,i], X_bounds[i]) + + new_samples = sample_from_parzen_estimator(low_mus, low_sigmas, X_bounds, num_points_to_try) + low_llik = llik_from_parzen_estimator(new_samples, low_mus, low_sigmas, X_bounds) + high_llik = llik_from_parzen_estimator(new_samples, high_mus, high_sigmas, X_bounds) + score = low_llik - high_llik + return ( new_samples[np.argmax(score),:], None, From e6df0fa8ebf62efee0f3ae48420e7592a06bf66c Mon Sep 17 00:00:00 2001 From: Ash Blum Date: Mon, 21 Jun 2021 04:32:32 -0600 Subject: [PATCH 04/13] Adaptive loss threshold that scales with sqrt of num samples --- bayes_search.py | 13 ++++++++----- 1 file changed, 8 insertions(+), 5 deletions(-) diff --git a/bayes_search.py b/bayes_search.py index f707880e..232e5a6d 100644 --- a/bayes_search.py +++ b/bayes_search.py @@ -395,8 +395,11 @@ def llik_from_1D_parzen_estimator(samples, mus, sigmas, x_min, x_max): return llik def parzen_threshold(y, gamma): - low_ind = int(np.floor(gamma * np.sqrt(len(y)))) - return np.sort(y)[low_ind] + num_low = int(np.ceil(gamma * np.sqrt(len(y)))) + low_ind = np.argsort(y)[0:num_low] + ret_val = np.array([False] * len(y)) + ret_val[low_ind] = True + return ret_val def next_sample_tpe( filtered_X: npt.ArrayLike, @@ -410,7 +413,7 @@ def next_sample_tpe( fit_1D: Optional[bool] = False ) -> Tuple[npt.ArrayLike, floating, floating, Optional[floating], Optional[floating]]: - y_star = np.quantile(filtered_y, improvement) + low_ind = parzen_threshold(filtered_y, improvement) if X_bounds is None: hp_min = np.min(filtered_X, axis=0) hp_max = np.max(filtered_X, axis=0) @@ -418,8 +421,8 @@ def next_sample_tpe( else: X_bounds = np.array(X_bounds) - low_X = filtered_X[filtered_y <= y_star] - high_X = filtered_X[filtered_y > y_star] + low_X = filtered_X[low_ind] + high_X = filtered_X[np.logical_not(low_ind)] num_hp = low_X.shape[1] # Fitting separate parzen estimators to each hyperparameter if fit_1D: From a7a8e4a210b2b7aa127d46872b60705aae311405 Mon Sep 17 00:00:00 2001 From: Ash Blum Date: Mon, 21 Jun 2021 15:40:53 -0600 Subject: [PATCH 05/13] Added magic formula bounds on sigma from reference implementation --- bayes_search.py | 149 ++++++++++++++++++++++++++++++------------------ 1 file changed, 93 insertions(+), 56 deletions(-) diff --git a/bayes_search.py b/bayes_search.py index 232e5a6d..3f502a64 100644 --- a/bayes_search.py +++ b/bayes_search.py @@ -249,30 +249,31 @@ def next_sample( None, ) - if model=="gp": - return next_sample_gp( - filtered_X=filtered_X, + if model == "tpe": + return next_sample_tpe( + filtered_X=filtered_X, filtered_y=filtered_y, X_bounds=X_bounds, current_X=current_X, - nu=nu, max_samples_for_model=max_samples_for_model, improvement=improvement, num_points_to_try=num_points_to_try, - opt_func=opt_func, - test_X=test_X - ) - elif model=="tpe": - return next_sample_tpe( - filtered_X=filtered_X, + test_X=test_X, + ) + else: # GP + return next_sample_gp( + filtered_X=filtered_X, filtered_y=filtered_y, X_bounds=X_bounds, current_X=current_X, + nu=nu, max_samples_for_model=max_samples_for_model, improvement=improvement, num_points_to_try=num_points_to_try, - test_X=test_X - ) + opt_func=opt_func, + test_X=test_X, + ) + def next_sample_gp( filtered_X: npt.ArrayLike, @@ -346,54 +347,77 @@ def next_sample_gp( suggested_X_expected_improvement, ) + def fit_1D_parzen_estimator(X, X_bounds): mus = X.copy() sorted_mus = np.sort(mus) extended_mus = np.insert(X_bounds, 1, sorted_mus) - sigmas = np.maximum(extended_mus[2:]-extended_mus[1:-1], extended_mus[1:-1] - extended_mus[0:-2]) - sigmas = np.maximum(sigmas, 1e-6) + sigmas = np.maximum( + extended_mus[2:] - extended_mus[1:-1], extended_mus[1:-1] - extended_mus[0:-2] + ) + + # Magic formula from reference implementation + prior_sigma = (X_bounds[1] - X_bounds[0]) / np.sqrt(12.0) + minsigma = prior_sigma / min(100.0, (1.0 + len(mus))) + sigmas = np.clip(sigmas, minsigma, prior_sigma) + return (mus, sigmas) + def sample_from_parzen_estimator(mus, sigmas, X_bounds, num_samples): which_mu = np.random.default_rng().integers(-1, len(mus), num_samples) samples = np.zeros((num_samples, len(X_bounds))) - uniform_ind = (which_mu == -1) + uniform_ind = which_mu == -1 num_uniform = np.count_nonzero(uniform_ind) samples[uniform_ind] = np.random.default_rng().uniform( - np.tile(X_bounds[:,0],[num_uniform, 1]), - np.tile(X_bounds[:,1],[num_uniform, 1]) ) - normal_ind = (which_mu >= 0) - num_normal = np.count_nonzero(normal_ind) + np.tile(X_bounds[:, 0], [num_uniform, 1]), + np.tile(X_bounds[:, 1], [num_uniform, 1]), + ) + normal_ind = which_mu >= 0 samples[normal_ind] = np.random.default_rng().normal( - loc = mus[which_mu[normal_ind]], - scale = sigmas[which_mu[normal_ind]] ) - return np.clip(samples, X_bounds[:,0], X_bounds[:,1]) + loc=mus[which_mu[normal_ind]], scale=sigmas[which_mu[normal_ind]] + ) + return np.clip(samples, X_bounds[:, 0], X_bounds[:, 1]) -def sample_from_1D_parzen_estimator(mus, sigmas, x_min, x_max, num_points_to_try): - indices = np.random.default_rng().integers(-1, len(low_X), num_points_to_try) + +def sample_from_1D_parzen_estimator(mus, sigmas, X_bounds, num_points_to_try): + indices = np.random.default_rng().integers(-1, len(mus), num_points_to_try) new_samples = np.zeros(num_points_to_try) # For which_mu == -1, sample from the (uniform) prior - new_samples[indices == -1] = np.random.default_rng().uniform(x_min, x_max, np.sum(indices == -1)) + new_samples[indices == -1] = np.random.default_rng().uniform( + X_bounds[0], X_bounds[1], np.sum(indices == -1) + ) # Other samples are from mus - new_samples[indices >= 0] = np.random.default_rng().normal(loc = mus[indices[indices >= 0]], - scale = sigmas[indices[indices >= 0]]) - return np.clip(new_samples, x_min, x_max) + new_samples[indices >= 0] = np.random.default_rng().normal( + loc=mus[indices[indices >= 0]], scale=sigmas[indices[indices >= 0]] + ) + return np.clip(new_samples, X_bounds[0], X_bounds[1]) + def llik_from_parzen_estimator(samples, mus, sigmas, X_bounds): samp_norm = (np.tile(samples, [len(mus), 1, 1]).transpose((1, 0, 2)) - mus) / sigmas + # alt_pdf = np.prod(scipy_stats.norm.pdf(samp_norm), axis=2) / np.prod(sigmas, axis=1) samp_norm = np.square(samp_norm) - normalization = (2.0 * np.pi) ** (-len(X_bounds) / 2.0) / np.prod(sigmas, axis=1) - pdf = normalization * np.exp(-0.5 * np.sum(samp_norm, axis = 2)) - uniform_pdf = 1.0 / np.prod(X_bounds[:,1] - X_bounds[:,0]) - mixture = (np.sum(pdf, axis = 1) + uniform_pdf) / (len(mus) + 1.0) + normalization = (2.0 * np.pi) ** (-len(X_bounds) / 2.0) / np.prod(sigmas, axis=1) + pdf = normalization * np.exp(-0.5 * np.sum(samp_norm, axis=2)) + uniform_pdf = 1.0 / np.prod(X_bounds[:, 1] - X_bounds[:, 0]) + mixture = (np.sum(pdf, axis=1) + uniform_pdf) / (len(mus) + 1.0) return np.log(mixture) -def llik_from_1D_parzen_estimator(samples, mus, sigmas, x_min, x_max): + +def llik_from_1D_parzen_estimator(samples, mus, sigmas, X_bounds): samp_norm = (np.tile(samples, [len(mus), 1]).T - mus) / sigmas - llik = np.log((np.sum(scipy_stats.norm.pdf(samp_norm), axis=1) + 1.0 / (x_max - x_min)) / (len(mus) + 1.0)) + llik = np.log( + ( + np.sum(scipy_stats.norm.pdf(samp_norm), axis=1) + + 1.0 / (X_bounds[1] - X_bounds[0]) + ) + / (len(mus) + 1.0) + ) return llik + def parzen_threshold(y, gamma): num_low = int(np.ceil(gamma * np.sqrt(len(y)))) low_ind = np.argsort(y)[0:num_low] @@ -401,6 +425,7 @@ def parzen_threshold(y, gamma): ret_val[low_ind] = True return ret_val + def next_sample_tpe( filtered_X: npt.ArrayLike, filtered_y: npt.ArrayLike, @@ -410,7 +435,7 @@ def next_sample_tpe( improvement: floating = 0.01, num_points_to_try: integer = 1000, test_X: Optional[npt.ArrayLike] = None, - fit_1D: Optional[bool] = False + fit_1D: Optional[bool] = False, ) -> Tuple[npt.ArrayLike, floating, floating, Optional[floating], Optional[floating]]: low_ind = parzen_threshold(filtered_y, improvement) @@ -431,12 +456,17 @@ def next_sample_tpe( high_llik = np.zeros((num_points_to_try, num_hp)) for i in range(num_hp): # Values below y_star - (x_min, x_max) = X_bounds[i] - (low_mus, low_sigmas) = fit_1D_parzen_estimator(low_X[:,i], x_min, x_max) - (high_mus, high_sigmas) = fit_1D_parzen_estimator(high_X[:,i], x_min, x_max) - new_samples[:, i] = sample_from_1D_parzen_estimator(low_mus, low_sigmas, x_min, x_max, num_points_to_try) - low_llik[:,i] = llik_from_1D_parzen_estimator(new_samples[:,i], low_mus, low_sigmas, x_min, x_max) - high_llik[:,i] = llik_from_1D_parzen_estimator(new_samples[:,i], high_mus, high_sigmas, x_min, x_max) + (low_mus, low_sigmas) = fit_1D_parzen_estimator(low_X[:, i], X_bounds[i]) + (high_mus, high_sigmas) = fit_1D_parzen_estimator(high_X[:, i], X_bounds[i]) + new_samples[:, i] = sample_from_1D_parzen_estimator( + low_mus, low_sigmas, X_bounds[i], num_points_to_try + ) + low_llik[:, i] = llik_from_1D_parzen_estimator( + new_samples[:, i], low_mus, low_sigmas, X_bounds[i] + ) + high_llik[:, i] = llik_from_1D_parzen_estimator( + new_samples[:, i], high_mus, high_sigmas, X_bounds[i] + ) score = np.sum(low_llik - high_llik, axis=1) # Fitting a multidimensional Parzen estimator else: @@ -446,21 +476,32 @@ def next_sample_tpe( high_sigmas = np.zeros((len(high_X), num_hp)) for i in range(num_hp): - (low_mus[:,i], low_sigmas[:,i]) = fit_1D_parzen_estimator(low_X[:,i], X_bounds[i]) - (high_mus[:,i], high_sigmas[:,i]) = fit_1D_parzen_estimator(high_X[:,i], X_bounds[i]) + (low_mus[:, i], low_sigmas[:, i]) = fit_1D_parzen_estimator( + low_X[:, i], X_bounds[i] + ) + (high_mus[:, i], high_sigmas[:, i]) = fit_1D_parzen_estimator( + high_X[:, i], X_bounds[i] + ) - new_samples = sample_from_parzen_estimator(low_mus, low_sigmas, X_bounds, num_points_to_try) - low_llik = llik_from_parzen_estimator(new_samples, low_mus, low_sigmas, X_bounds) - high_llik = llik_from_parzen_estimator(new_samples, high_mus, high_sigmas, X_bounds) + new_samples = sample_from_parzen_estimator( + low_mus, low_sigmas, X_bounds, num_points_to_try + ) + low_llik = llik_from_parzen_estimator( + new_samples, low_mus, low_sigmas, X_bounds + ) + high_llik = llik_from_parzen_estimator( + new_samples, high_mus, high_sigmas, X_bounds + ) score = low_llik - high_llik return ( - new_samples[np.argmax(score),:], + new_samples[np.argmax(score), :], None, None, None, None, - ) + ) + def bayes_search_next_run( runs: List[SweepRun], @@ -494,16 +535,12 @@ def bayes_search_next_run( if "metric" not in config: raise ValueError('Bayesian search requires "metric" section') - if isinstance(config["method"],str): - if config["method"] == "bayes": - config["method"] = { "name": "bayes", "model": "gp"} - else: - raise ValueError("Invalid sweep configuration for bayes_search_next_run.") - elif isinstance(config["method"],dict): - if config["method"]["model"] not in ("gp", "tpe"): + if isinstance(config["method"], dict): + model = config["method"]["model"] + if model not in ("gp", "tpe"): raise ValueError("Invalid sweep configuration for bayes_search_next_run.") else: - raise ValueError("Invalid sweep configuration for bayes_search_next_run.") + model = "gp" goal = config["metric"]["goal"] metric_name = config["metric"]["name"] From fc9e1693f84dd668d053a8ab84d778f9780141f4 Mon Sep 17 00:00:00 2001 From: Ash Blum Date: Thu, 8 Jul 2021 12:34:07 -0600 Subject: [PATCH 06/13] Cleaned up tpe_multi method and added bandwidth multiplier --- bayes_search.py | 130 ++++++++++++++++++++++--------------- run.py | 2 +- tests/test_bayes_search.py | 71 +++++++++++++++++--- 3 files changed, 142 insertions(+), 61 deletions(-) diff --git a/bayes_search.py b/bayes_search.py index 3f502a64..c825613b 100644 --- a/bayes_search.py +++ b/bayes_search.py @@ -155,7 +155,8 @@ def next_sample( current_X: Optional[npt.ArrayLike] = None, nu: floating = 1.5, max_samples_for_model: integer = 100, - improvement: floating = 0.01, + improvement: floating = 0.1, + bw_multiplier=0.2, num_points_to_try: integer = 1000, opt_func: str = "expected_improvement", model: str = "gp", @@ -184,6 +185,8 @@ def next_sample( randomly from the sample_X and used to train the GP. improvement: floating, optional, default 0.1 amount of improvement to optimize for -- higher means take more exploratory risks + bw_multiplier: floating, optional, default 0.2 + scaling factor for kernel density estimation bandwidth for tpe_multi algorithm num_points_to_try: integer, optional, default 1000 number of X values to try when looking for value with highest expected probability of improvement @@ -191,8 +194,8 @@ def next_sample( improvement of probability of improvement. Expected improvement is generally better - may want to remove probability of improvement at some point. (But I think prboability of improvement is a little easier to calculate) - model: one of {"gp", "tpe"} - whether to use a Gaussian Process as a surrogate model or - a Tree-structured Parzen Estimator + model: one of {"gp", "tpe", "tpe_multi"} - whether to use a Gaussian Process as a surrogate model, + a Tree-structured Parzen Estimator, or a multivariate TPE test_X: X values to test when looking for the best values to try Returns: @@ -259,6 +262,20 @@ def next_sample( improvement=improvement, num_points_to_try=num_points_to_try, test_X=test_X, + multivariate=False, + ) + elif model == "tpe_multi": + return next_sample_tpe( + filtered_X=filtered_X, + filtered_y=filtered_y, + X_bounds=X_bounds, + current_X=current_X, + max_samples_for_model=max_samples_for_model, + improvement=improvement, + num_points_to_try=num_points_to_try, + test_X=test_X, + multivariate=True, + bw_multiplier=bw_multiplier, ) else: # GP return next_sample_gp( @@ -348,34 +365,50 @@ def next_sample_gp( ) -def fit_1D_parzen_estimator(X, X_bounds): - mus = X.copy() - sorted_mus = np.sort(mus) - extended_mus = np.insert(X_bounds, 1, sorted_mus) - sigmas = np.maximum( +def fit_parzen_estimator_scott_bw(X, X_bounds, multiplier=1.06): + extended_X = np.insert(X_bounds.T, 1, X, axis=0) + mu = np.mean(extended_X, axis=0) + sumsqrs = np.sum(np.square(extended_X - mu), axis=0) + sigmahat = np.sqrt(sumsqrs / (len(extended_X) - 1)) + sigmas = multiplier * sigmahat * len(extended_X) ** (-1.0 / (4.0 + len(X_bounds))) + return np.tile(sigmas, [len(X), 1]) + + +def fit_1D_parzen_estimator_heuristic_bw(X, X_bounds): + sorted_ind = np.argsort(X.copy()) + sorted_mus = X[sorted_ind] + + # Treat endpoints of interval as data points + # extended_mus = np.insert(X_bounds, 1, sorted_mus) + + # Ignore endpoints of interval + extended_mus = np.insert([sorted_mus[0], sorted_mus[-1]], 1, sorted_mus) + + sigmas = np.zeros(len(X)) + sigmas[sorted_ind] = np.maximum( extended_mus[2:] - extended_mus[1:-1], extended_mus[1:-1] - extended_mus[0:-2] ) # Magic formula from reference implementation prior_sigma = (X_bounds[1] - X_bounds[0]) / np.sqrt(12.0) - minsigma = prior_sigma / min(100.0, (1.0 + len(mus))) + minsigma = prior_sigma / min(100.0, (1.0 + len(X))) sigmas = np.clip(sigmas, minsigma, prior_sigma) - return (mus, sigmas) + return sigmas def sample_from_parzen_estimator(mus, sigmas, X_bounds, num_samples): - which_mu = np.random.default_rng().integers(-1, len(mus), num_samples) + indices = np.random.default_rng().integers(-1, len(mus), num_samples) samples = np.zeros((num_samples, len(X_bounds))) - uniform_ind = which_mu == -1 + uniform_ind = indices == -1 num_uniform = np.count_nonzero(uniform_ind) samples[uniform_ind] = np.random.default_rng().uniform( np.tile(X_bounds[:, 0], [num_uniform, 1]), np.tile(X_bounds[:, 1], [num_uniform, 1]), ) - normal_ind = which_mu >= 0 + normal_ind = indices >= 0 samples[normal_ind] = np.random.default_rng().normal( - loc=mus[which_mu[normal_ind]], scale=sigmas[which_mu[normal_ind]] + loc=mus[indices[normal_ind]], scale=sigmas[indices[normal_ind]] ) return np.clip(samples, X_bounds[:, 0], X_bounds[:, 1]) @@ -397,7 +430,6 @@ def sample_from_1D_parzen_estimator(mus, sigmas, X_bounds, num_points_to_try): def llik_from_parzen_estimator(samples, mus, sigmas, X_bounds): samp_norm = (np.tile(samples, [len(mus), 1, 1]).transpose((1, 0, 2)) - mus) / sigmas - # alt_pdf = np.prod(scipy_stats.norm.pdf(samp_norm), axis=2) / np.prod(sigmas, axis=1) samp_norm = np.square(samp_norm) normalization = (2.0 * np.pi) ** (-len(X_bounds) / 2.0) / np.prod(sigmas, axis=1) pdf = normalization * np.exp(-0.5 * np.sum(samp_norm, axis=2)) @@ -410,7 +442,7 @@ def llik_from_1D_parzen_estimator(samples, mus, sigmas, X_bounds): samp_norm = (np.tile(samples, [len(mus), 1]).T - mus) / sigmas llik = np.log( ( - np.sum(scipy_stats.norm.pdf(samp_norm), axis=1) + np.sum(scipy_stats.norm.pdf(samp_norm) / sigmas, axis=1) + 1.0 / (X_bounds[1] - X_bounds[0]) ) / (len(mus) + 1.0) @@ -435,10 +467,10 @@ def next_sample_tpe( improvement: floating = 0.01, num_points_to_try: integer = 1000, test_X: Optional[npt.ArrayLike] = None, - fit_1D: Optional[bool] = False, + multivariate: Optional[bool] = False, + bw_multiplier: Optional[floating] = 1.0, ) -> Tuple[npt.ArrayLike, floating, floating, Optional[floating], Optional[floating]]: - low_ind = parzen_threshold(filtered_y, improvement) if X_bounds is None: hp_min = np.min(filtered_X, axis=0) hp_max = np.max(filtered_X, axis=0) @@ -446,42 +478,18 @@ def next_sample_tpe( else: X_bounds = np.array(X_bounds) + low_ind = parzen_threshold(filtered_y, improvement) low_X = filtered_X[low_ind] high_X = filtered_X[np.logical_not(low_ind)] - num_hp = low_X.shape[1] - # Fitting separate parzen estimators to each hyperparameter - if fit_1D: - new_samples = np.zeros((num_points_to_try, num_hp)) - low_llik = np.zeros((num_points_to_try, num_hp)) - high_llik = np.zeros((num_points_to_try, num_hp)) - for i in range(num_hp): - # Values below y_star - (low_mus, low_sigmas) = fit_1D_parzen_estimator(low_X[:, i], X_bounds[i]) - (high_mus, high_sigmas) = fit_1D_parzen_estimator(high_X[:, i], X_bounds[i]) - new_samples[:, i] = sample_from_1D_parzen_estimator( - low_mus, low_sigmas, X_bounds[i], num_points_to_try - ) - low_llik[:, i] = llik_from_1D_parzen_estimator( - new_samples[:, i], low_mus, low_sigmas, X_bounds[i] - ) - high_llik[:, i] = llik_from_1D_parzen_estimator( - new_samples[:, i], high_mus, high_sigmas, X_bounds[i] - ) - score = np.sum(low_llik - high_llik, axis=1) - # Fitting a multidimensional Parzen estimator - else: + num_hp = len(X_bounds) + if multivariate: low_mus = low_X.copy() low_sigmas = np.zeros((len(low_X), num_hp)) high_mus = high_X.copy() high_sigmas = np.zeros((len(high_X), num_hp)) - for i in range(num_hp): - (low_mus[:, i], low_sigmas[:, i]) = fit_1D_parzen_estimator( - low_X[:, i], X_bounds[i] - ) - (high_mus[:, i], high_sigmas[:, i]) = fit_1D_parzen_estimator( - high_X[:, i], X_bounds[i] - ) + low_sigmas = fit_parzen_estimator_scott_bw(low_X, X_bounds, bw_multiplier) + high_sigmas = fit_parzen_estimator_scott_bw(high_X, X_bounds) new_samples = sample_from_parzen_estimator( low_mus, low_sigmas, X_bounds, num_points_to_try @@ -493,9 +501,28 @@ def next_sample_tpe( new_samples, high_mus, high_sigmas, X_bounds ) score = low_llik - high_llik + best_sample = new_samples[np.argmax(score), :] + else: + # Fit separate 1D Parzen estimators to each hyperparameter + best_sample = np.zeros(num_hp) + for i in range(num_hp): + low_mus = low_X[:, i] + high_mus = high_X[:, i] + low_sigmas = fit_1D_parzen_estimator_heuristic_bw(low_mus, X_bounds[i]) + high_sigmas = fit_1D_parzen_estimator_heuristic_bw(high_mus, X_bounds[i]) + new_samples = sample_from_1D_parzen_estimator( + low_mus, low_sigmas, X_bounds[i], num_points_to_try + ) + low_llik = llik_from_1D_parzen_estimator( + new_samples, low_mus, low_sigmas, X_bounds[i] + ) + high_llik = llik_from_1D_parzen_estimator( + new_samples, high_mus, high_sigmas, X_bounds[i] + ) + best_sample[i] = new_samples[np.argmax(low_llik - high_llik)] return ( - new_samples[np.argmax(score), :], + best_sample, None, None, None, @@ -536,10 +563,11 @@ def bayes_search_next_run( raise ValueError('Bayesian search requires "metric" section') if isinstance(config["method"], dict): - model = config["method"]["model"] - if model not in ("gp", "tpe"): + model = config["method"]["bayes"]["model"] + if model not in ("gp", "tpe", "tpe_multi"): raise ValueError("Invalid sweep configuration for bayes_search_next_run.") else: + config["method"] = {"bayes": {"model": "gp"}} model = "gp" goal = config["metric"]["goal"] @@ -618,7 +646,7 @@ def bayes_search_next_run( sample_y=y, X_bounds=X_bounds, current_X=current_X if len(current_X) > 0 else None, - model=config["method"]["model"], + model=model, improvement=minimum_improvement, ) diff --git a/run.py b/run.py index 815d14c6..0ff7fd09 100644 --- a/run.py +++ b/run.py @@ -135,7 +135,7 @@ def next_run( return grid_search_next_run(runs, sweep_config, validate=validate, **kwargs) elif method == "random": return random_search_next_run(sweep_config, validate=validate) - elif method == "bayes": + elif method == "bayes" or isinstance(method, dict) and "bayes" in method.keys(): return bayes_search_next_run(runs, sweep_config, validate=validate, **kwargs) else: raise ValueError( diff --git a/tests/test_bayes_search.py b/tests/test_bayes_search.py index 072ee79e..7ce7fd95 100644 --- a/tests/test_bayes_search.py +++ b/tests/test_bayes_search.py @@ -22,6 +22,36 @@ def rosenbrock(x: npt.ArrayLike) -> np.floating: return np.sum((x[1:] - x[:-1] ** 2.0) ** 2.0 + (1 - x[:-1]) ** 2.0) +rastrigin_A = 10.0 + + +def rastrigin(x: npt.ArrayLike) -> np.floating: + # has a global minimum at (1, 1, 1, ...) with x_i \in [-4.12, 6.12] + return rastrigin_A * len(x) + np.sum( + np.square(x - 1.0) - rastrigin_A * np.cos(2 * np.pi * (x - 1.0)) + ) + + +shekel_beta = np.array([0.1, 0.2, 0.2, 0.4, 0.4, 0.6, 0.3, 0.7, 0.5, 0.5]) +shekel_C = np.array( + [ + [4.0, 1.0, 8.0, 6.0, 3.0, 2.0, 5.0, 8.0, 6.0, 7.0], + [4.0, 1.0, 8.0, 6.0, 7.0, 9.0, 3.0, 1.0, 2.0, 3.6], + [4.0, 1.0, 8.0, 6.0, 3.0, 2.0, 5.0, 8.0, 6.0, 7.0], + [4.0, 1.0, 8.0, 6.0, 7.0, 9.0, 3.0, 1.0, 2.0, 3.6], + ] +).T +shekel_min = -10.5364 + + +def shekel(x: npt.ArrayLike) -> np.floating: + # Four minima in bounding region [[0,10]]*4 + return ( + -np.sum(np.reciprocal(np.sum(np.square(shekel_C - x), axis=1) + shekel_beta)) + - shekel_min + ) + + def run_bayes_search( f: Callable[[SweepRun], floating], config: SweepConfig, @@ -98,13 +128,15 @@ def run_iterations( optimium: Optional[npt.ArrayLike] = None, atol: Optional[npt.ArrayLike] = 0.2, chunk_size: integer = 1, - model: str = "gp" + model: str = "gp", + bw_multiplier: floating = 1.0, ) -> Tuple[npt.ArrayLike, npt.ArrayLike]: + bounds = np.array(bounds) if x_init is not None: X = x_init else: - X = [np.zeros(len(bounds))] + X = [np.random.uniform(low=bounds[:, 0], high=bounds[:, 1])] y = np.array([f(x) for x in X]).flatten() @@ -120,11 +152,13 @@ def run_iterations( X_bounds=bounds, current_X=sample_X, improvement=improvement, - model=model + model=model, + bw_multiplier=bw_multiplier, ) if sample_X is None: sample_X = np.array([sample]) - sample_X = np.append(sample_X, np.array([sample]), axis=0) + else: + sample_X = np.append(sample_X, np.array([sample]), axis=0) counter += 1 print( "X: {} prob(I): {} pred: {} value: {}".format( @@ -175,7 +209,7 @@ def test_squiggle_convergence(): run_iterations(squiggle, [[0.0, 5.0]], 200, x_init, optimium=[3.6], atol=0.2) -@pytest.mark.parametrize("model",["gp","tpe"]) +@pytest.mark.parametrize("model", ["gp", "tpe", "tpe_multi"]) def test_squiggle_convergence_to_maximum(model): # This test checks whether the bayes algorithm correctly explores the parameter space # we sample a ton of positive examples, ignoring the negative side @@ -183,7 +217,21 @@ def f(x): return -squiggle(x) x_init = np.random.uniform(0, 5, 1)[:, None] - run_iterations(f, [[0.0, 5.0]], 200, x_init, optimium=[2], atol=0.2, model=model) + if model == "tpe" or model == "tpe_multi": + improvement = 0.15 + else: + improvement = 0.1 + + run_iterations( + f, + [[0.0, 5.0]], + 200, + x_init, + improvement=improvement, + optimium=[2], + atol=0.2, + model=model, + ) def test_nans(): @@ -209,11 +257,16 @@ def test_squiggle_int(): assert np.isclose(sample % 1, 0) -@pytest.mark.parametrize("model",["gp","tpe"]) +@pytest.mark.parametrize("model", ["gp", "tpe", "tpe_multi"]) def test_iterations_rosenbrock(model): dimensions = 3 # x_init = np.random.uniform(0, 2, size=(1, dimensions)) x_init = np.zeros((1, dimensions)) + if model == "tpe" or model == "tpe_multi": + improvement = 0.15 + else: + improvement = 0.1 + run_iterations( rosenbrock, [[0.0, 2.0]] * dimensions, @@ -221,8 +274,8 @@ def test_iterations_rosenbrock(model): x_init, optimium=[1, 1, 1], atol=0.2, - improvement=0.1, - model=model + improvement=improvement, + model=model, ) From f48523b105703cc0cfc1dda2b80bcad6dce4e937 Mon Sep 17 00:00:00 2001 From: Ash Blum Date: Wed, 4 Aug 2021 10:18:36 -0400 Subject: [PATCH 07/13] Updated JSON schema for TPE, added some tests for it --- config/schema.json | 34 ++++++++++++++++++++++++++++++---- tests/test_validation.py | 40 ++++++++++++++++++++++++++++++++++++++++ 2 files changed, 70 insertions(+), 4 deletions(-) diff --git a/config/schema.json b/config/schema.json index 0aded897..be83a935 100644 --- a/config/schema.json +++ b/config/schema.json @@ -520,10 +520,36 @@ }, "method": { "description": "Possible values: bayes, random, grid", - "enum": [ - "bayes", - "random", - "grid" + "oneOf": [ + { + "type": "string", + "enum": [ + "bayes", + "random", + "grid" + ] + }, + { + "type": "object", + "properties": { + "bayes": { + "type": "object", + "properties": { + "model": { + "type": "string", + "description": "The model type for the Bayesian optimizer", + "enum": [ + "gp", + "tpe", + "tpe_multi" + ] + } + }, + "required": ["model"] + } + }, + "required": ["bayes"] + } ] }, "command": { diff --git a/tests/test_validation.py b/tests/test_validation.py index aedab2f7..f447e0da 100644 --- a/tests/test_validation.py +++ b/tests/test_validation.py @@ -38,6 +38,46 @@ def test_validation_disable(search_type): assert result is not None +def test_bayes_methods(): + schema = { + "method": {"bayes": {"model": "tpex"}}, + "metric": {"name": "loss", "goal": "minimize"}, + "parameters": {"a": {"min": 0, "max": 1, "distribution": "uniform"}}, + } + with pytest.raises(ValidationError): + SweepConfig(schema) + + schema = { + "method": {"bayes": "gp"}, + "metric": {"name": "loss", "goal": "minimize"}, + "parameters": {"a": {"min": 0, "max": 1, "distribution": "uniform"}}, + } + with pytest.raises(ValidationError): + SweepConfig(schema) + + schema = { + "method": "baye", + "metric": {"name": "loss", "goal": "minimize"}, + "parameters": {"a": {"min": 0, "max": 1, "distribution": "uniform"}}, + } + with pytest.raises(ValidationError): + SweepConfig(schema) + + schema = { + "method": {"bayes": {"model": "tpe"}}, + "metric": {"name": "loss", "goal": "minimize"}, + "parameters": {"a": {"min": 0, "max": 1, "distribution": "uniform"}}, + } + SweepConfig(schema) + + schema = { + "method": "bayes", + "metric": {"name": "loss", "goal": "minimize"}, + "parameters": {"a": {"min": 0, "max": 1, "distribution": "uniform"}}, + } + SweepConfig(schema) + + def test_validation_not_enough_params(): schema = {"method": "random", "parameters": {}} From b515ae22b89270e66f46c48bfdf64c317381a3ba Mon Sep 17 00:00:00 2001 From: Danny Goldstein Date: Sun, 3 Oct 2021 11:48:07 -0700 Subject: [PATCH 08/13] fix up some tests --- src/sweeps/bayes_search.py | 26 ++++++++++++++++------ src/sweeps/config/schema.json | 41 ++++++++--------------------------- src/sweeps/run.py | 4 ++-- tests/test_bayes_search.py | 16 +++++++------- tests/test_validation.py | 12 +++++----- tox.ini | 2 +- 6 files changed, 46 insertions(+), 55 deletions(-) diff --git a/src/sweeps/bayes_search.py b/src/sweeps/bayes_search.py index f87133bd..4de2d2e8 100644 --- a/src/sweeps/bayes_search.py +++ b/src/sweeps/bayes_search.py @@ -486,7 +486,13 @@ def next_sample_tpe( test_X: Optional[ArrayLike] = None, multivariate: Optional[bool] = False, bw_multiplier: Optional[floating] = 1.0, -) -> Tuple[ArrayLike, floating, floating, Optional[floating], Optional[floating]]: +) -> Tuple[ + ArrayLike, + Optional[floating], + Optional[floating], + Optional[floating], + Optional[floating], +]: if X_bounds is None: hp_min = np.min(filtered_X, axis=0) @@ -579,13 +585,19 @@ def bayes_search_next_run( if "metric" not in config: raise ValueError('Bayesian search requires "metric" section') - if isinstance(config["method"], dict): - model = config["method"]["bayes"]["model"] - if model not in ("gp", "tpe", "tpe_multi"): - raise ValueError("Invalid sweep configuration for bayes_search_next_run.") - else: - config["method"] = {"bayes": {"model": "gp"}} + if "method" not in config: + raise ValueError("Method must be specified") + + if config["method"] == "bayes": model = "gp" + elif config["method"] == "bayes-tpe": + model = "tpe" + elif config["method"] == "bayes-tpe-multi": + model = "tpe_multi" + else: + raise ValueError( + 'Invalid method for bayes_search_next_run, must be one of "bayes", "bayes-tpe", "bayes-tpe-multi"' + ) goal = config["metric"]["goal"] metric_name = config["metric"]["name"] diff --git a/src/sweeps/config/schema.json b/src/sweeps/config/schema.json index 41b50084..0de2463e 100644 --- a/src/sweeps/config/schema.json +++ b/src/sweeps/config/schema.json @@ -566,38 +566,15 @@ "description": "The project for this sweep" }, "method": { - "description": "Possible values: bayes, random, grid", - "oneOf": [ - { - "type": "string", - "enum": [ - "bayes", - "random", - "grid", - "custom" - ] - }, - { - "type": "object", - "properties": { - "bayes": { - "type": "object", - "properties": { - "model": { - "type": "string", - "description": "The model type for the Bayesian optimizer", - "enum": [ - "gp", - "tpe", - "tpe_multi" - ] - } - }, - "required": ["model"] - } - }, - "required": ["bayes"] - } + "description": "Possible values: bayes, random, grid, bayes-tpe, bayes-tpe-multi, custom", + "type": "string", + "enum": [ + "bayes", + "random", + "grid", + "bayes-tpe", + "bayes-tpe-multi", + "custom" ] }, "command": { diff --git a/src/sweeps/run.py b/src/sweeps/run.py index 3c074284..67d949fc 100644 --- a/src/sweeps/run.py +++ b/src/sweeps/run.py @@ -191,13 +191,13 @@ def next_runs( ) elif method == "random": return random_search_next_runs(sweep_config, validate=validate, n=n) - elif method == "bayes": + elif method in ["bayes", "bayes-tpe", "bayes-tpe-multi"]: return bayes_search_next_runs( runs, sweep_config, validate=validate, n=n, **kwargs ) else: raise ValueError( - f'Invalid search type {method}, must be one of ["grid", "random", "bayes"]' + f'Invalid search type {method}, must be one of ["grid", "random", "bayes", "bayes-tpe", or "bayes-tpe-multi"]' ) diff --git a/tests/test_bayes_search.py b/tests/test_bayes_search.py index 0e8ce29d..d7d073c0 100644 --- a/tests/test_bayes_search.py +++ b/tests/test_bayes_search.py @@ -212,15 +212,15 @@ def test_squiggle_convergence(): run_iterations(squiggle, [[0.0, 5.0]], 200, x_init, optimium=[3.6], atol=0.2) -@pytest.mark.parametrize("model", ["gp", "tpe", "tpe_multi"]) -def test_squiggle_convergence_to_maximum(model): +@pytest.mark.parametrize("method", ["bayes", "bayes-tpe", "bayes-tpe-multi"]) +def test_squiggle_convergence_to_maximum(method): # This test checks whether the bayes algorithm correctly explores the parameter space # we sample a ton of positive examples, ignoring the negative side def f(x): return -squiggle(x) x_init = np.random.uniform(0, 5, 1)[:, None] - if model == "tpe" or model == "tpe_multi": + if method == "bayes-tpe" or method == "bayes-tpe-multi": improvement = 0.15 else: improvement = 0.1 @@ -233,7 +233,7 @@ def f(x): improvement=improvement, optimium=[2], atol=0.2, - model=model, + model=method, ) @@ -260,12 +260,12 @@ def test_squiggle_int(): assert np.isclose(sample % 1, 0) -@pytest.mark.parametrize("model", ["gp", "tpe", "tpe_multi"]) -def test_iterations_rosenbrock(model): +@pytest.mark.parametrize("method", ["bayes", "bayes-tpe", "bayes-tpe-multi"]) +def test_iterations_rosenbrock(method): dimensions = 3 # x_init = np.random.uniform(0, 2, size=(1, dimensions)) x_init = np.zeros((1, dimensions)) - if model == "tpe" or model == "tpe_multi": + if method == "bayes-tpe" or method == "bayes-tpe-multi": improvement = 0.15 else: improvement = 0.1 @@ -278,7 +278,7 @@ def test_iterations_rosenbrock(model): optimium=[1, 1, 1], atol=0.2, improvement=improvement, - model=model, + model=method, ) diff --git a/tests/test_validation.py b/tests/test_validation.py index 10b1cadc..43ea5423 100644 --- a/tests/test_validation.py +++ b/tests/test_validation.py @@ -8,7 +8,9 @@ from sweeps.hyperband_stopping import hyperband_stop_runs -@pytest.mark.parametrize("search_type", ["bayes", "grid", "random"]) +@pytest.mark.parametrize( + "search_type", ["bayes", "grid", "random", "bayes-tpe", "bayes-tpe-multi"] +) def test_validation_disable(search_type): invalid_schema = { "metric": {"name": "loss", "goal": "minimise"}, @@ -26,7 +28,7 @@ def test_validation_disable(search_type): _ = hyperband_stop_runs([], invalid_schema, validate=True) with pytest.raises(ValidationError): - if search_type == "bayes": + if "bayes" in search_type: _ = bayes_search_next_runs([], invalid_schema, validate=True) elif search_type == "grid": _ = grid_search_next_runs([], invalid_schema, validate=True) @@ -40,7 +42,7 @@ def test_validation_disable(search_type): def test_bayes_methods(): schema = { - "method": {"bayes": {"model": "tpex"}}, + "method": "bayes-tpex", "metric": {"name": "loss", "goal": "minimize"}, "parameters": {"a": {"min": 0, "max": 1, "distribution": "uniform"}}, } @@ -48,7 +50,7 @@ def test_bayes_methods(): SweepConfig(schema) schema = { - "method": {"bayes": "gp"}, + "method": "bayes-gp", "metric": {"name": "loss", "goal": "minimize"}, "parameters": {"a": {"min": 0, "max": 1, "distribution": "uniform"}}, } @@ -64,7 +66,7 @@ def test_bayes_methods(): SweepConfig(schema) schema = { - "method": {"bayes": {"model": "tpe"}}, + "method": "bayes-tpe", "metric": {"name": "loss", "goal": "minimize"}, "parameters": {"a": {"min": 0, "max": 1, "distribution": "uniform"}}, } diff --git a/tox.ini b/tox.ini index 0da0e6c8..5bb80c38 100644 --- a/tox.ini +++ b/tox.ini @@ -10,4 +10,4 @@ install_command = pip install {opts} {packages} commands = - pytest + pytest {posargs} From 8f22f0d912dad555a3beb1d5312bb19025f92107 Mon Sep 17 00:00:00 2001 From: Danny Goldstein Date: Sun, 3 Oct 2021 12:43:59 -0700 Subject: [PATCH 09/13] put _construct_model_data back in --- src/sweeps/bayes_search.py | 103 ++++++++++++++++++++----------------- tests/test_bayes_search.py | 2 +- 2 files changed, 56 insertions(+), 49 deletions(-) diff --git a/src/sweeps/bayes_search.py b/src/sweeps/bayes_search.py index 4de2d2e8..42fbb236 100644 --- a/src/sweeps/bayes_search.py +++ b/src/sweeps/bayes_search.py @@ -553,52 +553,9 @@ def next_sample_tpe( ) -def bayes_search_next_run( - runs: List[SweepRun], - config: Union[dict, SweepConfig], - validate: bool = False, - minimum_improvement: floating = 0.1, -) -> SweepRun: - """Suggest runs using Bayesian optimization. - - >>> suggestion = bayes_search_next_run([], { - ... 'method': 'bayes', - ... 'parameters': {'a': {'min': 1., 'max': 2.}}, - ... 'metric': {'name': 'loss', 'goal': 'maximize'} - ... }) - - Args: - runs: The runs in the sweep. - config: The sweep's config. - minimum_improvement: The minimium improvement to optimize for. Higher means take more exploratory risks. - validate: Whether to validate `sweep_config` against the SweepConfig JSONschema. - If true, will raise a Validation error if `sweep_config` does not conform to - the schema. If false, will attempt to run the sweep with an unvalidated schema. - - Returns: - The suggested run. - """ - - if validate: - config = SweepConfig(config) - - if "metric" not in config: - raise ValueError('Bayesian search requires "metric" section') - - if "method" not in config: - raise ValueError("Method must be specified") - - if config["method"] == "bayes": - model = "gp" - elif config["method"] == "bayes-tpe": - model = "tpe" - elif config["method"] == "bayes-tpe-multi": - model = "tpe_multi" - else: - raise ValueError( - 'Invalid method for bayes_search_next_run, must be one of "bayes", "bayes-tpe", "bayes-tpe-multi"' - ) - +def _construct_bayes_data( + runs: List[SweepRun], config: Union[dict, SweepConfig] +) -> Tuple[HyperParameterSet, ArrayLike, ArrayLike, ArrayLike]: goal = config["metric"]["goal"] metric_name = config["metric"]["name"] worst_func = min if goal == "maximize" else max @@ -611,8 +568,6 @@ def bayes_search_next_run( current_X: ArrayLike = [] y: ArrayLike = [] - X_bounds = [[0.0, 1.0]] * len(params.searchable_params) - # we calc the max metric to put as the metric for failed runs # so that our bayesian search stays away from them worst_metric: floating = np.inf if goal == "maximize" else -np.inf @@ -674,6 +629,58 @@ def bayes_search_next_run( # maximize, we need to negate y y *= -1 if goal == "maximize" else 1 + return params, sample_X, current_X, y + + +def bayes_search_next_run( + runs: List[SweepRun], + config: Union[dict, SweepConfig], + validate: bool = False, + minimum_improvement: floating = 0.1, +) -> SweepRun: + """Suggest runs using Bayesian optimization. + + >>> suggestion = bayes_search_next_run([], { + ... 'method': 'bayes', + ... 'parameters': {'a': {'min': 1., 'max': 2.}}, + ... 'metric': {'name': 'loss', 'goal': 'maximize'} + ... }) + + Args: + runs: The runs in the sweep. + config: The sweep's config. + minimum_improvement: The minimium improvement to optimize for. Higher means take more exploratory risks. + validate: Whether to validate `sweep_config` against the SweepConfig JSONschema. + If true, will raise a Validation error if `sweep_config` does not conform to + the schema. If false, will attempt to run the sweep with an unvalidated schema. + + Returns: + The suggested run. + """ + + if validate: + config = SweepConfig(config) + + if "metric" not in config: + raise ValueError('Bayesian search requires "metric" section') + + if "method" not in config: + raise ValueError("Method must be specified") + + if config["method"] == "bayes": + model = "gp" + elif config["method"] == "bayes-tpe": + model = "tpe" + elif config["method"] == "bayes-tpe-multi": + model = "tpe_multi" + else: + raise ValueError( + 'Invalid method for bayes_search_next_run, must be one of "bayes", "bayes-tpe", "bayes-tpe-multi"' + ) + + params, sample_X, current_X, y = _construct_bayes_data(runs, config) + X_bounds = [[0.0, 1.0]] * len(params.searchable_params) + ( suggested_X, suggested_X_prob_of_improvement, diff --git a/tests/test_bayes_search.py b/tests/test_bayes_search.py index d7d073c0..2f5fde68 100644 --- a/tests/test_bayes_search.py +++ b/tests/test_bayes_search.py @@ -927,7 +927,7 @@ def test_metric_extremum_in_bayes_search(): data_path = f"{os.path.dirname(__file__)}/data/ygnwe8ptupj33get.decoded.json" with open(data_path, "r") as f: data = json.load(f) - _, _, _, y = bayes._construct_gp_data( + _, _, _, y = bayes._construct_bayes_data( [SweepRun(**r) for r in data["jsonPayload"]["data"]["runs"]], data["jsonPayload"]["data"]["config"], ) From fc2dd26d294dca31384b519f7fddb137b2135ba2 Mon Sep 17 00:00:00 2001 From: Danny Goldstein Date: Sat, 23 Oct 2021 14:48:29 -0700 Subject: [PATCH 10/13] save work (tests may not pass) --- src/sweeps/bayes_search.py | 14 ++++---------- tests/test_bayes_search.py | 18 +++++++++++++++--- 2 files changed, 19 insertions(+), 13 deletions(-) diff --git a/src/sweeps/bayes_search.py b/src/sweeps/bayes_search.py index 42fbb236..b6fb87c8 100644 --- a/src/sweeps/bayes_search.py +++ b/src/sweeps/bayes_search.py @@ -269,7 +269,7 @@ def next_sample( np.nan, ) - if model == "tpe": + if model == "bayes-tpe": return next_sample_tpe( filtered_X=filtered_X, filtered_y=filtered_y, @@ -281,7 +281,7 @@ def next_sample( test_X=test_X, multivariate=False, ) - elif model == "tpe_multi": + elif model == "bayes-tpe-multi": return next_sample_tpe( filtered_X=filtered_X, filtered_y=filtered_y, @@ -667,13 +667,7 @@ def bayes_search_next_run( if "method" not in config: raise ValueError("Method must be specified") - if config["method"] == "bayes": - model = "gp" - elif config["method"] == "bayes-tpe": - model = "tpe" - elif config["method"] == "bayes-tpe-multi": - model = "tpe_multi" - else: + if config["method"] not in ["bayes", "bayes-tpe", "bayes-tpe-multi"]: raise ValueError( 'Invalid method for bayes_search_next_run, must be one of "bayes", "bayes-tpe", "bayes-tpe-multi"' ) @@ -692,7 +686,7 @@ def bayes_search_next_run( sample_y=y, X_bounds=X_bounds, current_X=current_X if len(current_X) > 0 else None, - model=model, + model=config["method"], improvement=minimum_improvement, ) diff --git a/tests/test_bayes_search.py b/tests/test_bayes_search.py index 2f5fde68..c6b8110d 100644 --- a/tests/test_bayes_search.py +++ b/tests/test_bayes_search.py @@ -1,5 +1,6 @@ import os import json +import time from typing import Callable, Optional, Tuple, Iterable, Dict, Union import pytest @@ -157,6 +158,7 @@ def run_iterations( improvement=improvement, model=model, bw_multiplier=bw_multiplier, + max_samples_for_model=100, ) if sample_X is None: sample_X = np.array([sample]) @@ -213,7 +215,7 @@ def test_squiggle_convergence(): @pytest.mark.parametrize("method", ["bayes", "bayes-tpe", "bayes-tpe-multi"]) -def test_squiggle_convergence_to_maximum(method): +def test_squiggle_convergence_to_maximum(method, capsys): # This test checks whether the bayes algorithm correctly explores the parameter space # we sample a ton of positive examples, ignoring the negative side def f(x): @@ -225,16 +227,21 @@ def f(x): else: improvement = 0.1 + start = time.time() run_iterations( f, [[0.0, 5.0]], - 200, + 2000, x_init, improvement=improvement, optimium=[2], atol=0.2, model=method, ) + stop = time.time() + + with capsys.disabled(): + print(f"took {stop-start:0.2f} seconds") def test_nans(): @@ -261,7 +268,7 @@ def test_squiggle_int(): @pytest.mark.parametrize("method", ["bayes", "bayes-tpe", "bayes-tpe-multi"]) -def test_iterations_rosenbrock(method): +def test_iterations_rosenbrock(method, capsys): dimensions = 3 # x_init = np.random.uniform(0, 2, size=(1, dimensions)) x_init = np.zeros((1, dimensions)) @@ -270,6 +277,7 @@ def test_iterations_rosenbrock(method): else: improvement = 0.1 + start = time.time() run_iterations( rosenbrock, [[0.0, 2.0]] * dimensions, @@ -280,6 +288,10 @@ def test_iterations_rosenbrock(method): improvement=improvement, model=method, ) + stop = time.time() + + with capsys.disabled(): + print(f"took {stop-start:0.2f} seconds") def test_iterations_squiggle_chunked(): From ed2b8820cbd6e81d088a2a0d5e0d0033ae0cab6f Mon Sep 17 00:00:00 2001 From: Danny Goldstein Date: Wed, 27 Oct 2021 13:16:52 -0700 Subject: [PATCH 11/13] lint --- src/sweeps/bayes_search.py | 21 +++++++-------------- 1 file changed, 7 insertions(+), 14 deletions(-) diff --git a/src/sweeps/bayes_search.py b/src/sweeps/bayes_search.py index b6fb87c8..8ad8c87e 100644 --- a/src/sweeps/bayes_search.py +++ b/src/sweeps/bayes_search.py @@ -320,7 +320,7 @@ def next_sample_gp( num_points_to_try: integer = 1000, opt_func: str = "expected_improvement", test_X: Optional[ArrayLike] = None, -) -> Tuple[ArrayLike, floating, floating, Optional[floating], Optional[floating]]: +) -> Tuple[ArrayLike, floating, floating, floating, floating]: # build the acquisition function gp, y_mean, y_stddev, = train_gaussian_process( filtered_X, filtered_y, X_bounds, current_X, nu, max_samples_for_model @@ -486,13 +486,7 @@ def next_sample_tpe( test_X: Optional[ArrayLike] = None, multivariate: Optional[bool] = False, bw_multiplier: Optional[floating] = 1.0, -) -> Tuple[ - ArrayLike, - Optional[floating], - Optional[floating], - Optional[floating], - Optional[floating], -]: +) -> Tuple[ArrayLike, floating, floating, floating, floating]: if X_bounds is None: hp_min = np.min(filtered_X, axis=0) @@ -507,9 +501,7 @@ def next_sample_tpe( num_hp = len(X_bounds) if multivariate: low_mus = low_X.copy() - low_sigmas = np.zeros((len(low_X), num_hp)) high_mus = high_X.copy() - high_sigmas = np.zeros((len(high_X), num_hp)) low_sigmas = fit_parzen_estimator_scott_bw(low_X, X_bounds, bw_multiplier) high_sigmas = fit_parzen_estimator_scott_bw(high_X, X_bounds) @@ -544,12 +536,13 @@ def next_sample_tpe( ) best_sample[i] = new_samples[np.argmax(low_llik - high_llik)] + # TODO: replace nans with actual values return ( best_sample, - None, - None, - None, - None, + np.nan, + np.nan, + np.nan, + np.nan, ) From ff65af2b648f17716bf0b16639552842ba63fcd2 Mon Sep 17 00:00:00 2001 From: Danny Goldstein Date: Wed, 27 Oct 2021 14:06:38 -0700 Subject: [PATCH 12/13] coerce arrays to lists in random_sample() --- src/sweeps/bayes_search.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/sweeps/bayes_search.py b/src/sweeps/bayes_search.py index 8ad8c87e..271b5e59 100644 --- a/src/sweeps/bayes_search.py +++ b/src/sweeps/bayes_search.py @@ -57,6 +57,8 @@ def sigmoid(x: ArrayLike) -> ArrayLike: def random_sample(X_bounds: ArrayLike, num_test_samples: integer) -> ArrayLike: + if hasattr(X_bounds, "tolist"): + X_bounds = X_bounds.tolist() num_hyperparameters = len(X_bounds) test_X = np.empty((int(num_test_samples), num_hyperparameters)) for ii in range(num_test_samples): From d90015c380f7ada25644c6d8bca669dafa35867e Mon Sep 17 00:00:00 2001 From: Ash Blum Date: Tue, 30 Nov 2021 11:35:51 -0500 Subject: [PATCH 13/13] Compute expected improvement and prob of improvement from TPE --- src/sweeps/bayes_search.py | 44 +++++++++++++++++++++++++++++++------- tests/test_bayes_search.py | 6 +++--- 2 files changed, 39 insertions(+), 11 deletions(-) diff --git a/src/sweeps/bayes_search.py b/src/sweeps/bayes_search.py index 271b5e59..bbbdfdc0 100644 --- a/src/sweeps/bayes_search.py +++ b/src/sweeps/bayes_search.py @@ -477,6 +477,24 @@ def parzen_threshold(y, gamma): return ret_val +def stats_from_parzen_estimator(low_llik, high_llik, y, gamma, low_ind): + gamma_rescaled = np.sum(low_ind)/len(y) + y_star = (np.min(y[np.logical_not(low_ind)]) + np.max(y[low_ind]))/2.0 + unnorm_mean_low_y = np.sum(y[low_ind]) / len(y) + unnorm_mean_high_y = np.mean(y) - unnorm_mean_low_y + unnorm_mean_low_y_sq = np.sum(np.square(y[low_ind])) / len(y) + unnorm_mean_high_y_sq = np.mean(np.square(y)) - unnorm_mean_low_y_sq + l_of_x = np.exp(low_llik) + g_of_x = np.exp(high_llik) + p_of_x = gamma_rescaled * l_of_x + (1-gamma_rescaled) * g_of_x + y_pred = (l_of_x * unnorm_mean_low_y + g_of_x * unnorm_mean_high_y) / p_of_x + y_sq_pred = (l_of_x * unnorm_mean_low_y_sq + g_of_x * unnorm_mean_high_y_sq) / p_of_x + y_std = np.sqrt(y_sq_pred - y_pred * y_pred) + prob_of_improvement = l_of_x * gamma_rescaled / p_of_x + expected_improvement = l_of_x * (gamma_rescaled * y_star - unnorm_mean_low_y) / p_of_x + return prob_of_improvement, y_pred, y_std, expected_improvement + + def next_sample_tpe( filtered_X: ArrayLike, filtered_y: ArrayLike, @@ -487,7 +505,7 @@ def next_sample_tpe( num_points_to_try: integer = 1000, test_X: Optional[ArrayLike] = None, multivariate: Optional[bool] = False, - bw_multiplier: Optional[floating] = 1.0, + bw_multiplier: Optional[floating] = 0.2, ) -> Tuple[ArrayLike, floating, floating, floating, floating]: if X_bounds is None: @@ -506,7 +524,7 @@ def next_sample_tpe( high_mus = high_X.copy() low_sigmas = fit_parzen_estimator_scott_bw(low_X, X_bounds, bw_multiplier) - high_sigmas = fit_parzen_estimator_scott_bw(high_X, X_bounds) + high_sigmas = fit_parzen_estimator_scott_bw(high_X, X_bounds, bw_multiplier) new_samples = sample_from_parzen_estimator( low_mus, low_sigmas, X_bounds, num_points_to_try @@ -518,10 +536,15 @@ def next_sample_tpe( new_samples, high_mus, high_sigmas, X_bounds ) score = low_llik - high_llik - best_sample = new_samples[np.argmax(score), :] + best_index = np.argmax(score) + best_sample = new_samples[best_index, :] + best_low_llik = low_llik[best_index] + best_high_llik = high_llik[best_index] else: # Fit separate 1D Parzen estimators to each hyperparameter best_sample = np.zeros(num_hp) + best_low_llik = 0 + best_high_llik = 0 for i in range(num_hp): low_mus = low_X[:, i] high_mus = high_X[:, i] @@ -536,15 +559,20 @@ def next_sample_tpe( high_llik = llik_from_1D_parzen_estimator( new_samples, high_mus, high_sigmas, X_bounds[i] ) - best_sample[i] = new_samples[np.argmax(low_llik - high_llik)] + best_index = np.argmax(low_llik - high_llik) + best_sample[i] = new_samples[best_index] + best_low_llik += low_llik[best_index] + best_high_llik += high_llik[best_index] + + (prob_of_improvement, predicted_y, predicted_std, expected_improvement) = stats_from_parzen_estimator(best_low_llik, best_high_llik, filtered_y, improvement, low_ind) # TODO: replace nans with actual values return ( best_sample, - np.nan, - np.nan, - np.nan, - np.nan, + prob_of_improvement, + predicted_y, + predicted_std, + expected_improvement, ) diff --git a/tests/test_bayes_search.py b/tests/test_bayes_search.py index c6b8110d..25d163c4 100644 --- a/tests/test_bayes_search.py +++ b/tests/test_bayes_search.py @@ -150,7 +150,7 @@ def run_iterations( for cc in range(chunk_size): if counter >= num_iterations: break - (sample, prob, pred, _, _,) = bayes.next_sample( + (sample, prob, pred, pred_std, exp_imp) = bayes.next_sample( sample_X=X, sample_y=y, X_bounds=bounds, @@ -166,8 +166,8 @@ def run_iterations( sample_X = np.append(sample_X, np.array([sample]), axis=0) counter += 1 print( - "X: {} prob(I): {} pred: {} value: {}".format( - sample, prob, pred, f(sample) + "X: {} prob(I): {} EI: {} pred: {} value: {} pred_std: {}".format( + sample, prob, exp_imp, pred, f(sample), pred_std ) )