diff --git a/pylift/methods/base.py b/pylift/methods/base.py index 4b56430..b11bbeb 100644 --- a/pylift/methods/base.py +++ b/pylift/methods/base.py @@ -70,7 +70,7 @@ class BaseProxyMethod: randomized search scoring function. If a list of scoring_methods is passed, a dictionary can also be passed here, where the keys are the scoring_method strings and the values are the scoring cutoff for those - specific methods. + specific methods. sklearn_model : scikit-learn regressor Model used for grid searching and fitting. @@ -264,7 +264,7 @@ def __init__(self, df, transform_func, untransform_func, col_treatment='Treatmen 'estimator': self.sklearn_model(), **default_params } - self.bayes_search_params = { + self.bayes_search_params = { 'estimator': self.sklearn_model(), **default_params } @@ -366,15 +366,15 @@ def grid_search(self, **kwargs): def bayes_search(self, **kwargs): """ Grid search using skopt.BayesSearchCV - + Any parameters typically associated with BayesSearchCV (see - Scikit-Optimize documentation) can be passed as keyword arguments to + Scikit-Optimize documentation) can be passed as keyword arguments to this function. - + The final dictionary used for the grid search is saved to `self.bayes_search_params`. This is updated with any parameters that are passed. - + Examples -------- # Passing kwargs. @@ -384,7 +384,7 @@ def bayes_search(self, **kwargs): self.bayes_search_ = BayesSearchCV(**self.bayes_search_params) self.bayes_search_.fit(self.x_train, self.transformed_y_train) return self.bayes_search_ - + def fit(self, productionize=False, **kwargs): """A fit wrapper around any sklearn Regressor. @@ -594,4 +594,3 @@ def plot(self, plot_type='cgains', ax=None, n_bins=None, show_noise_fits=False, ax.plot([x[0], x[-1]], [bs_means[0], bs_means[-1]], '--', color=[0.6,0.6,0.6]) return ax -