-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathid3_decision_tree.py
170 lines (119 loc) · 4.71 KB
/
id3_decision_tree.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
'''
Implementation of the ID3 decision tree algorithm.
Utilises some of the logic I learnt through Dataquest.
email: dat.nguyen at cantab.net
'''
import numpy as np
import math
import pandas as pd
class ID3tree(object):
    """ID3 decision tree classifier for binary (0/1) integer targets.

    Each node splits one numeric feature at its median, choosing the feature
    with the highest information gain. The fitted tree is a nested dict held
    in ``self.tree`` with keys: 'number', and either 'label' (leaf) or
    'split_feature'/'median'/'left'/'right' (internal node).
    """

    def __init__(self):
        self.tree = {}   # nested dict built by fit(); empty until trained
        self.nodes = []  # node ids, appended in creation order

    def _calc_entropy(self, col):
        '''
        col: a Series, list, or numpy array of non-negative integer labels.
        Return the Shannon entropy (base 2) of the label distribution.
        '''
        # BUG FIX: original referenced an undefined name `column` instead of
        # the parameter `col`, raising NameError on every call.
        col = np.asarray(col)
        counts = np.bincount(col)  # unique value counts
        probabilities = counts / len(col)
        entropy = 0.0
        for prob in probabilities:
            if prob > 0:  # skip zero-probability bins: 0*log(0) := 0
                entropy += prob * math.log(prob, 2)
        return -entropy

    def _calc_info_gain(self, data, split_name, target):
        '''
        data: dataset in a structured format (e.g., DataFrame)
        split_name: name of the column (feature) to split on
        target: target variable (column name) as a string
        Return the information gain of a median split on `split_name`.
        '''
        assert isinstance(split_name, str)
        assert isinstance(target, str)
        original_entropy = self._calc_entropy(data[target])
        column = data[split_name]
        # split will occur on the median
        median = column.median()
        left_split = data[column <= median]
        right_split = data[column > median]
        weighted_entropies = 0
        for subset in [left_split, right_split]:
            prob = (subset.shape[0] / data.shape[0])
            weighted_entropies += prob * self._calc_entropy(subset[target])
        return original_entropy - weighted_entropies

    def _find_split(self, data, target, features):
        '''
        data: dataset in a structured format (e.g., DataFrame)
        target: target variable (column name) as a string
        features: list of feature names
        Return the feature whose median split yields the highest info gain.
        '''
        info_gains = [self._calc_info_gain(data, f, target) for f in features]
        return features[np.argmax(info_gains)]

    def _train(self, data, target, features, tree):
        '''
        data: dataset in a structured format (e.g., DataFrame)
        target: target variable (column name) as a string
        features: list of feature names
        tree: dict node to populate in place
        Recursively grow the tree using the ID3 algorithm.
        '''
        assert isinstance(tree, dict)
        unique_targets = pd.unique(data[target])
        self.nodes.append(len(self.nodes) + 1)
        tree['number'] = self.nodes[-1]
        # if singleton class, create a leaf node
        if len(unique_targets) == 1:
            tree['label'] = 0 if unique_targets[0] == 0 else 1
            return
        best_column = self._find_split(data, target, features)
        column_median = data[best_column].median()
        left_split = data[data[best_column] <= column_median]
        right_split = data[data[best_column] > column_median]
        # BUG FIX: guard against a degenerate split (one side empty, e.g. all
        # remaining values equal the median), which previously recursed forever.
        # Fall back to a majority-class leaf.
        if left_split.shape[0] == 0 or right_split.shape[0] == 0:
            tree['label'] = int(data[target].value_counts().idxmax())
            return
        tree['split_feature'] = best_column
        tree['median'] = column_median
        for key, split in (('left', left_split), ('right', right_split)):
            tree[key] = {}
            # BUG FIX: original called the non-existent self.train(...)
            self._train(split, target, features, tree[key])

    def fit(self, X, y, features=None):
        '''
        X: DataFrame containing the feature columns AND the target column
        y: name of the target column in X (string)
        features: optional list of feature names; defaults to every column
                  of X except the target
        '''
        # assert a previous tree has not been trained on this instance
        assert not self.tree
        if features is None:
            # BUG FIX: original used X.columns wholesale, which included the
            # target column itself as a candidate split feature.
            features = [c for c in X.columns if c != y]
        # BUG FIX: original called the non-existent self.train(...)
        self._train(X, y, features, self.tree)

    def predict(self, X_test):
        '''
        X_test: DataFrame of feature columns.
        Return a Series of 0/1 predictions, one per row of X_test.
        '''
        # functional approach
        def _predict_row(tree, row):
            if "label" in tree:
                return tree["label"]
            # BUG FIX: _train stores the key 'split_feature'; the original
            # read tree["column"], which never exists -> KeyError.
            column = tree["split_feature"]
            median = tree["median"]
            if row[column] <= median:
                return _predict_row(tree['left'], row)
            return _predict_row(tree['right'], row)

        return X_test.apply(lambda row: _predict_row(self.tree, row), axis=1)
# utils
def print_node(tree, depth):
    """Recursively pretty-print a fitted tree dict, indenting by node depth."""
    if "label" in tree:
        print_with_depth("Leaf: Label {0}".format(tree["label"]), depth)
        return
    print_with_depth("{0} > {1}".format(tree["split_feature"], tree["median"]), depth)
    for child in (tree["left"], tree["right"]):
        print_node(child, depth + 1)
def print_with_depth(string, depth):
    """Print *string* preceded by one leading space per level of depth."""
    print("{0}{1}".format(" " * depth, string))