diff --git a/asreview2-optuna/classifiers.py b/asreview2-optuna/classifiers.py index 58c1135..fd37cc3 100644 --- a/asreview2-optuna/classifiers.py +++ b/asreview2-optuna/classifiers.py @@ -4,9 +4,10 @@ NaiveBayesClassifier, LogisticClassifier, SVMClassifier, - RandomForestClassifier, ) +from sklearn.ensemble import RandomForestClassifier + def naive_bayes_params(trial: optuna.trial.FrozenTrial): # Use logarithmic normal distribution for alpha (alpha effect is non-linear) @@ -52,9 +53,26 @@ def random_forest_params(trial: optuna.trial.FrozenTrial): } +class RFClassifier(RandomForestClassifier): + """Random forest classifier. + + Based on the sklearn implementation of the random forest + sklearn.ensemble.RandomForestClassifier. + """ + + name = "rf" + label = "Random forest" + + def __init__(self, n_estimators=100, max_features=10, **kwargs): + super().__init__( + n_estimators=int(n_estimators), + max_features=max_features, + **kwargs, + ) + classifiers = { "nb": NaiveBayesClassifier, "log": LogisticClassifier, "svm": SVMClassifier, - "rf": RandomForestClassifier, + "rf": RFClassifier, }