-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathDecesionTree2.py
250 lines (224 loc) · 9.76 KB
/
DecesionTree2.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
# -*- coding: utf-8 -*-
from numpy import *
import numpy as np
import pandas as pd
from math import log
import copy
import operator
import re
# Compute the Gini impurity of a data set.
def calcGini(dataSet):
    """Return the Gini impurity of dataSet.

    Each sample is a list whose last element is the class label.
    Gini = 1 - sum(p_k^2) over all class labels k; 0.0 means pure.
    """
    numEntries = len(dataSet)
    # Count how often each class label occurs (label is the last column).
    labelCounts = {}
    for featVec in dataSet:
        currentLabel = featVec[-1]
        labelCounts[currentLabel] = labelCounts.get(currentLabel, 0) + 1
    # NOTE: this is the Gini index, not Shannon entropy — the original
    # comment claiming a base-2 entropy computation was wrong.
    Gini = 1.0
    for key in labelCounts:
        prob = float(labelCounts[key]) / numEntries
        Gini -= prob * prob
    return Gini
# Split on a discrete feature: keep every sample whose value in column
# `axis` equals `value`, dropping that column from each kept sample.
def splitDataSet(dataSet, axis, value):
    return [sample[:axis] + sample[axis + 1:]
            for sample in dataSet
            if sample[axis] == value]
# Split on a continuous feature.  `direction` selects which side to keep:
# 0 keeps samples with feature > value, anything else keeps feature <= value.
# The split column is removed from every kept sample.
def splitContinuousDataSet(dataSet, axis, value, direction):
    if direction == 0:
        keep = lambda v: v > value
    else:
        keep = lambda v: v <= value
    return [sample[:axis] + sample[axis + 1:]
            for sample in dataSet
            if keep(sample[axis])]
# Pick the feature whose split yields the lowest weighted Gini impurity.
def chooseBestFeatureToSplit(dataSet, labels):
    """Return the index of the best feature to split dataSet on.

    Continuous features (int/float columns) are evaluated over the n-1
    midpoints of their sorted values.  If the winning feature is
    continuous, this function MUTATES its arguments: labels[best]
    becomes 'name<=threshold' and the feature column in dataSet is
    binarised to 1 (value <= threshold) / 0 (value > threshold).

    Fixes vs. the original:
      * a continuous feature whose values are all identical raised a
        NameError (no candidate split point existed); it is now skipped;
      * numpy's shape() replaced by len(), removing the dependency on
        `from numpy import *`;
      * binarisation is guarded by bestFeature >= 0.
    """
    numFeatures = len(dataSet[0]) - 1       # last column is the class label
    bestGiniIndex = float('inf')
    bestFeature = -1
    bestSplitDict = {}                      # feature label -> best threshold
    for i in range(numFeatures):
        featList = [example[i] for example in dataSet]
        if isinstance(featList[0], (int, float)):
            # Continuous: candidate thresholds are midpoints between
            # consecutive sorted values.
            sortedVals = sorted(featList)
            splitList = [(sortedVals[j] + sortedVals[j + 1]) / 2.0
                         for j in range(len(sortedVals) - 1)]
            bestSplitGini = float('inf')
            bestSplit = -1
            for j, value in enumerate(splitList):
                subDataSet0 = splitContinuousDataSet(dataSet, i, value, 0)
                subDataSet1 = splitContinuousDataSet(dataSet, i, value, 1)
                # Weighted Gini over the two sides of the threshold.
                newGiniIndex = (
                    len(subDataSet0) / float(len(dataSet)) * calcGini(subDataSet0)
                    + len(subDataSet1) / float(len(dataSet)) * calcGini(subDataSet1))
                if newGiniIndex < bestSplitGini:
                    bestSplitGini = newGiniIndex
                    bestSplit = j
            if bestSplit < 0:
                # All values identical: no usable split for this feature
                # (original code raised NameError here).
                continue
            bestSplitDict[labels[i]] = splitList[bestSplit]
            GiniIndex = bestSplitGini
        else:
            # Discrete: weighted Gini over every distinct value.
            GiniIndex = 0.0
            for value in set(featList):
                subDataSet = splitDataSet(dataSet, i, value)
                prob = len(subDataSet) / float(len(dataSet))
                GiniIndex += prob * calcGini(subDataSet)
        if GiniIndex < bestGiniIndex:
            bestGiniIndex = GiniIndex
            bestFeature = i
    # If the winning feature is continuous, binarise it in place:
    # 1 when value <= threshold, else 0, and record the threshold in
    # the label so classify() can recover it later.
    if bestFeature >= 0 and isinstance(dataSet[0][bestFeature], (int, float)):
        bestSplitValue = bestSplitDict[labels[bestFeature]]
        labels[bestFeature] = labels[bestFeature] + '<=' + str(bestSplitValue)
        for row in dataSet:
            row[bestFeature] = 1 if row[bestFeature] <= bestSplitValue else 0
    return bestFeature
# When no feature separates the samples, fall back to a majority vote.
def majorityCnt(classList):
    """Return the most frequent class label in classList.

    Bug fix: the original returned ``max(classCount)``, i.e. the largest
    dictionary KEY (alphabetically last label), not the label with the
    highest vote count.
    """
    classCount = {}
    for vote in classList:
        classCount[vote] = classCount.get(vote, 0) + 1
    # Pick the key with the maximum count, not the maximum key.
    return max(classCount, key=classCount.get)
# Main routine: recursively grow the decision tree.
def createTree(dataSet, labels, data_full, labels_full):
    """Recursively build a decision tree, returned as nested dicts.

    ``labels`` is mutated (the chosen feature's label is deleted at each
    node); ``data_full``/``labels_full`` hold the complete training data
    so discrete feature values absent at this node can still be mapped
    to a majority-vote leaf.
    """
    classList = [example[-1] for example in dataSet]
    # Stop: every sample at this node already has the same class.
    if classList.count(classList[0]) == len(classList):
        return classList[0]
    # Stop: only the class column remains, so vote on the class.
    if len(dataSet[0]) == 1:
        return majorityCnt(classList)
    bestFeat = chooseBestFeatureToSplit(dataSet, labels)
    bestFeatLabel = labels[bestFeat]
    myTree = {bestFeatLabel: {}}
    featValues = [example[bestFeat] for example in dataSet]
    uniqueVals = set(featValues)
    # For a discrete feature, also collect its value set over the FULL
    # data so values missing at this node can be handled after the loop.
    if type(dataSet[0][bestFeat]).__name__ == 'str':
        currentlabel = labels_full.index(labels[bestFeat])
        featValuesFull = [example[currentlabel] for example in data_full]
        uniqueValsFull = set(featValuesFull)
    del (labels[bestFeat])
    # Grow one subtree for every value of bestFeat observed at this node.
    for value in uniqueVals:
        subLabels = labels[:]
        if type(dataSet[0][bestFeat]).__name__ == 'str':
            uniqueValsFull.remove(value)
        myTree[bestFeatLabel][value] = createTree(splitDataSet \
            (dataSet, bestFeat, value), subLabels, data_full, labels_full)
    # Feature values seen in the full data but not at this node become
    # majority-vote leaves so classification never hits a missing branch.
    if type(dataSet[0][bestFeat]).__name__ == 'str':
        for value in uniqueValsFull:
            myTree[bestFeatLabel][value] = majorityCnt(classList)
    return myTree
# Classify one sample by walking the tree from the root.
def classify(inputTree, featLabels, testVec):
    """Return the class label the tree predicts for testVec.

    Continuous-split nodes are keyed 'name<=threshold' with child keys
    1 (value <= threshold) and 0 (value > threshold); discrete nodes are
    keyed by the feature name with one child per feature value.
    Returns None when no branch matches testVec (the original raised
    UnboundLocalError in that case).
    """
    # Py2-only `inputTree.keys()[0]` is not subscriptable in Py3;
    # next(iter(...)) works in both.
    firstStr = next(iter(inputTree))
    classLabel = None
    if '<=' in firstStr:
        # Continuous node: parse "name<=value" back into name + threshold.
        featvalue = float(re.compile("(<=.+)").search(firstStr).group()[2:])
        featkey = re.compile("(.+<=)").search(firstStr).group()[:-2]
        secondDict = inputTree[firstStr]
        featIndex = featLabels.index(featkey)
        judge = 1 if testVec[featIndex] <= featvalue else 0
        for key in secondDict.keys():
            if judge == int(key):
                if isinstance(secondDict[key], dict):
                    classLabel = classify(secondDict[key], featLabels, testVec)
                else:
                    classLabel = secondDict[key]
    else:
        # Discrete node: follow the branch whose key equals the value.
        secondDict = inputTree[firstStr]
        featIndex = featLabels.index(firstStr)
        for key in secondDict.keys():
            if testVec[featIndex] == key:
                if isinstance(secondDict[key], dict):
                    classLabel = classify(secondDict[key], featLabels, testVec)
                else:
                    classLabel = secondDict[key]
    return classLabel
# Evaluate the tree: count its misclassifications on held-out data.
def testing(myTree, data_test, labels):
    """Return, as a float, how many rows of data_test the tree gets wrong."""
    wrong = sum(1 for row in data_test
                if classify(myTree, labels, row) != row[-1])
    return float(wrong)
# 测试投票节点正确率
def testingMajor(major, data_test):
error = 0.0
for i in range(len(data_test)):
if major != data_test[i][-1]:
error += 1
# print 'major %d' %error
return float(error)
# Post-pruning: bottom-up, replace a subtree with a majority-vote leaf
# whenever the leaf does no worse on the held-out test data.
def postPruningTree(inputTree, dataSet, data_test, labels):
    """Prune ``inputTree`` against ``data_test``; return either the
    (possibly modified) tree or a single majority-class leaf label.

    ``labels`` is mutated (the root feature's label is removed).
    NOTE(review): relies on the ``copy`` module (not imported at the top
    of the file — add ``import copy``) and on Py2's subscriptable
    ``dict.keys()`` — confirm this runs under Python 2.
    """
    firstStr = inputTree.keys()[0]
    secondDict = inputTree[firstStr]
    classList = [example[-1] for example in dataSet]
    featkey = copy.deepcopy(firstStr)
    if '<=' in firstStr:
        # Continuous node: the key looks like "name<=threshold".
        featkey = re.compile("(.+<=)").search(firstStr).group()[:-2]
        featvalue = float(re.compile("(<=.+)").search(firstStr).group()[2:])
    labelIndex = labels.index(featkey)
    # Keep a copy of the labels for the error evaluation below.
    temp_labels = copy.deepcopy(labels)
    del (labels[labelIndex])
    # Recurse into every child that is itself a subtree.
    for key in secondDict.keys():
        if type(secondDict[key]).__name__ == 'dict':
            if type(dataSet[0][labelIndex]).__name__ == 'str':
                # Discrete split: the child key is the feature value.
                inputTree[firstStr][key] = postPruningTree(secondDict[key], \
                    splitDataSet(dataSet, labelIndex, key),
                    splitDataSet(data_test, labelIndex, key),
                    copy.deepcopy(labels))
            else:
                # Continuous split: the child key (0/1) doubles as the
                # `direction` argument of splitContinuousDataSet.
                inputTree[firstStr][key] = postPruningTree(secondDict[key], \
                    splitContinuousDataSet(dataSet, labelIndex, featvalue, key), \
                    splitContinuousDataSet(data_test, labelIndex, featvalue,
                                           key), \
                    copy.deepcopy(labels))
    # Keep the subtree only if it ties or beats the majority-vote leaf
    # on the held-out test error; otherwise collapse to a leaf.
    if testing(inputTree, data_test, temp_labels) <= testingMajor(majorityCnt(classList), data_test):
        return inputTree
    return majorityCnt(classList)
# ---- Script entry: train on the first 1941 rows, prune against rows 1800+ ----
df = pd.read_csv('csvtest.csv')
data = df.values[0:1941, 0:].tolist()
data_full = data[:]
data_test = df.values[1800:, 0:].tolist()
labels = df.columns.values[0:-1].tolist()
labels_full = labels[:]
myTree = createTree(data, labels, data_full, labels_full)
# print(x) is valid in Python 2 and 3; the original `print myTree`
# statement is a syntax error under Python 3.
print(myTree)
# createTree consumed `data` (columns removed) and `labels` (entries
# deleted / renamed to 'name<=threshold'), so rebuild BOTH before
# pruning.  The original reloaded the data but passed the mutated
# `labels`, which mis-indexes features inside postPruningTree.
data = df.values[0:1941, 0:].tolist()
data_full = data[:]
data_test = df.values[1800:, 0:].tolist()
labels = df.columns.values[0:-1].tolist()
myTree2 = postPruningTree(myTree, data, data_test, labels)