-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathRandomForest.py
38 lines (31 loc) · 1.32 KB
/
RandomForest.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
from DecisionTree import DecisionTree
import numpy as np
from collections import Counter
class RandomForest:
    """Ensemble of ``DecisionTree`` models combined by majority vote.

    ``fit`` trains ``n_trees`` trees, each on a bootstrap resample (rows drawn
    with replacement) of the training data. ``predict`` asks every tree for
    its labels and returns, per sample, the label most trees voted for.
    """

    def __init__(
        self,
        n_trees=10,
        max_depth=10,
        min_samples_split=2,
        n_features=None,
    ) -> None:
        """Store the hyper-parameters; no training happens here.

        Args:
            n_trees: Number of trees that ``fit`` builds.
            max_depth: Passed unchanged to every ``DecisionTree``.
            min_samples_split: Passed unchanged to every ``DecisionTree``.
            n_features: Passed unchanged to every ``DecisionTree``.
        """
        self.n_trees = n_trees
        self.max_depth = max_depth
        self.min_samples_split = min_samples_split
        self.n_features = n_features
        self.trees = []

    @staticmethod
    def _as_indexable(data):
        """Return *data* as something that supports ``.shape`` and ``data[idxs]``.

        Objects that already expose ``shape`` are returned untouched so that
        existing callers see no change; plain sequences (lists, tuples) are
        converted to a NumPy array.
        """
        return data if hasattr(data, "shape") else np.asarray(data)

    def fit(self, X, y):
        """Train ``n_trees`` trees, each on its own bootstrap sample.

        Any trees from an earlier ``fit`` call are discarded, so refitting
        replaces the model instead of silently growing the forest.

        Args:
            X: Training samples, one row per sample.
            y: Labels aligned with the rows of ``X``.
        """
        # Convert once up front so the per-tree resampling does not repeat it.
        X = self._as_indexable(X)
        y = self._as_indexable(y)

        # Build into a local list and publish at the end: a refit starts from
        # scratch, and a failure midway leaves the previous forest intact.
        trees = []
        for _ in range(self.n_trees):
            tree = DecisionTree(
                max_depth=self.max_depth,
                min_samples_split=self.min_samples_split,
                n_features=self.n_features,
            )
            X_sample, y_sample = self._bootstrap_samples(X, y)
            tree.fit(X_sample, y_sample)
            trees.append(tree)
        self.trees = trees

    def _bootstrap_samples(self, X, y):
        """Draw ``len(X)`` rows with replacement from ``X`` and ``y``.

        The same random indices are applied to both, so every sampled row
        keeps its own label. Randomness comes from NumPy's global
        ``np.random`` state.

        Returns:
            Tuple ``(X_sample, y_sample)``.
        """
        X = self._as_indexable(X)
        y = self._as_indexable(y)
        n_samples = X.shape[0]
        idxs = np.random.choice(n_samples, n_samples, replace=True)
        return X[idxs], y[idxs]

    def _most_common_label(self, y):
        """Return the most frequent value in the iterable ``y``."""
        counter = Counter(y)
        return counter.most_common(1)[0][0]

    def predict(self, X):
        """Predict one label per sample by majority vote over all trees.

        Args:
            X: Samples to classify; handed as-is to each tree's ``predict``.

        Returns:
            ``np.ndarray`` holding one voted label per sample.

        Raises:
            ValueError: If the forest holds no trees (``fit`` not called yet).
        """
        if not self.trees:
            # Without this guard the axis swap below fails with an obscure
            # NumPy axis error on the empty 1-D array.
            raise ValueError(
                "RandomForest has no fitted trees; call fit() before predict()."
            )
        # Shape (n_trees, n_samples): one row of predictions per tree.
        per_tree = np.array([tree.predict(X) for tree in self.trees])
        # Shape (n_samples, n_trees): one row of votes per sample.
        per_sample = np.swapaxes(per_tree, 0, 1)
        return np.array(
            [self._most_common_label(votes) for votes in per_sample]
        )