-
Notifications
You must be signed in to change notification settings - Fork 180
/
Copy pathray_joblib.py
69 lines (51 loc) · 1.94 KB
/
ray_joblib.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
"""
Optuna example that optimizes a classifier configuration for the Iris dataset using Ray with
joblib backend.
In this example, we optimize a classifier configuration for the Iris dataset. The classifiers are from
scikit-learn. We optimize both the choice of classifier (between SVC and RandomForest) and their
hyperparameters. Each trial's cross-validation folds are parallelized through joblib's Ray backend.
You can execute this example as follows:
$ python ray_joblib.py
"""
import joblib
import optuna
import ray
from ray.util.joblib import register_ray
import sklearn.datasets
import sklearn.ensemble
import sklearn.model_selection
import sklearn.svm
# Connect to Ray: attach to an already-running cluster if one can be
# discovered via address="auto"; if none is found (ConnectionError), start a
# local Ray instance for this process instead.
#
# After this try/except Ray is always initialized, so a further
# ``ray.init(ignore_reinit_error=True)`` call would only log and return; it
# has been dropped as redundant.
try:
    ray.init(address="auto")
except ConnectionError:
    ray.init()

# Register the "ray" backend so ``joblib.parallel_backend("ray", ...)`` can
# dispatch joblib work (e.g. scikit-learn's n_jobs parallelism) onto Ray.
register_ray()
def objective(trial):
    """Sample a classifier configuration and return its mean 3-fold CV accuracy on Iris.

    The trial chooses a classifier family ("SVC" or "RandomForest") and then
    samples that family's single hyperparameter:

    * SVC: ``svc_c`` (regularization ``C``), log-uniform in [1e-10, 1e10].
    * RandomForest: ``rf_max_depth``, log-scaled integer in [2, 32], with a
      fixed 10 trees.

    ``cross_val_score`` is called with ``n_jobs=-1``, so its folds run on
    whatever joblib backend is active at call time.

    Args:
        trial: the Optuna trial used to sample hyperparameters.

    Returns:
        The mean accuracy across the three cross-validation folds.
    """
    features, labels = sklearn.datasets.load_iris(return_X_y=True)

    family = trial.suggest_categorical("classifier", ["SVC", "RandomForest"])

    if family != "SVC":
        depth = trial.suggest_int("rf_max_depth", 2, 32, log=True)
        model = sklearn.ensemble.RandomForestClassifier(
            max_depth=depth, n_estimators=10
        )
    else:
        regularization = trial.suggest_float("svc_c", 1e-10, 1e10, log=True)
        model = sklearn.svm.SVC(C=regularization, gamma="auto")

    fold_scores = sklearn.model_selection.cross_val_score(
        model, features, labels, n_jobs=-1, cv=3
    )
    return fold_scores.mean()
if __name__ == "__main__":
    study = optuna.create_study(direction="maximize")

    # Every joblib call made inside ``objective`` (the CV folds in
    # ``cross_val_score``) is routed to the "ray" backend registered via
    # ``register_ray()`` at import time.
    with joblib.parallel_backend("ray", n_jobs=-1):
        study.optimize(objective, n_trials=100)

    # Summary: trial count and wall-clock span from the first trial's start
    # to the last trial's completion.
    all_trials = study.trials
    print(f"Number of finished trials: {len(all_trials)}")
    elapsed = all_trials[-1].datetime_complete - all_trials[0].datetime_start
    print(f"Elapsed time: {elapsed}")

    print("Best trial:")
    best = study.best_trial
    print(f" Value: {best.value}")
    print(" Params: ")
    for param_name, param_value in best.params.items():
        print(f" {param_name}: {param_value}")