## python 决策树 math库 c45算法 原

悲喜世界

C4.5 是 ID3 算法的升级版：它用信息增益率代替信息增益来选择划分特征。个人建议使用 CART 算法，感觉效果比 C4.5 更好。

``````#!/usr/bin/python
#coding:utf-8

import operator
from math import log
import time
import os,sys
import string

# Load the training data from a file.
def createDataSet(trainDataFile):
    """Load a comma-separated training file into rows plus feature labels.

    Every non-blank line must hold at least eleven comma-separated columns.
    Column 0 is moved to the end of the row, so a returned row is
    ``[col1, ..., col10, col0]``; all values are kept as strings.

    Args:
        trainDataFile: path of the training data file.

    Returns:
        A ``(dataSet, labels)`` tuple: the list of rows and the names of
        the ten feature columns, in row order.

    On failure (file cannot be opened, or a line has fewer than eleven
    columns) a usage hint is printed and the process ends via
    ``sys.exit()``.
    """
    print(trainDataFile)
    dataSet = []
    try:
        # ``with`` releases the handle even when a malformed line aborts the load.
        with open(trainDataFile) as fin:
            for line in fin:
                # Drop the line terminator, including a Windows-style '\r'.
                line = line.rstrip('\r\n')
                if not line:
                    # A blank line carries no sample; skip it instead of aborting.
                    continue
                cols = line.split(',')
                # Explicit indexing (not slicing) so a short row raises IndexError.
                row = [cols[i] for i in range(1, 11)]
                row.append(cols[0])
                dataSet.append(row)
    except (IOError, OSError, IndexError, TypeError):
        # Only the expected load failures are handled; anything else
        # (e.g. KeyboardInterrupt) propagates instead of being swallowed.
        print('Usage xxx.py trainDataFilePath')
        sys.exit()
    labels = ['cip1', 'cip2', 'cip3', 'cip4', 'sip1', 'sip2', 'sip3', 'sip4', 'sport', 'domain']
    print('dataSetlen %d' % len(dataSet))
    return dataSet, labels

#c4.5 信息熵算法
def calcShannonEntOfFeature(dataSet, feat):
    """Return the base-2 Shannon entropy of one column of ``dataSet``.

    Args:
        dataSet: list of rows (indexable sequences).
        feat: index of the column whose value distribution is measured;
            ``-1`` selects the last column.

    Returns:
        The entropy as a float; ``0.0`` for an empty data set.
    """
    total = len(dataSet)
    tally = {}
    for row in dataSet:
        observed = row[feat]
        tally[observed] = tally.get(observed, 0) + 1

    entropy = 0.0
    for count in tally.values():
        fraction = float(count) / total
        entropy -= fraction * log(fraction, 2)
    return entropy

def splitDataSet(dataSet, axis, value):
    """Select the rows whose column ``axis`` equals ``value``, minus that column.

    Args:
        dataSet: list of rows, each a list.
        axis: index of the column to test and remove.
        value: value the column must equal for a row to be kept.

    Returns:
        A new list of new rows; the input rows are not modified.
    """
    return [
        row[:axis] + row[axis + 1:]
        for row in dataSet
        if row[axis] == value
    ]

def chooseBestFeatureToSplit(dataSet):
    """Pick the feature column with the highest information gain ratio.

    The last column of every row is treated as the class; all other
    columns are candidate features.

    Args:
        dataSet: non-empty list of rows.

    Returns:
        The index of the best feature, or ``-1`` when no feature has a
        gain ratio greater than zero.
    """
    featureCount = len(dataSet[0]) - 1
    classEntropy = calcShannonEntOfFeature(dataSet, -1)
    sampleCount = float(len(dataSet))

    winner = -1
    winnerRatio = 0.0
    for column in range(featureCount):
        # Entropy of the class after splitting on this column,
        # weighted by the size of each branch.
        conditionalEntropy = 0.0
        for observed in set([row[column] for row in dataSet]):
            branch = splitDataSet(dataSet, column, observed)
            weight = len(branch) / sampleCount
            conditionalEntropy += weight * calcShannonEntOfFeature(branch, -1)
        gain = classEntropy - conditionalEntropy

        # Split information: entropy of the feature's own value distribution.
        splitInfo = calcShannonEntOfFeature(dataSet, column)
        if splitInfo == 0:
            # A single-valued column cannot split anything (and would divide by zero).
            continue

        ratio = gain / splitInfo
        if ratio > winnerRatio:
            winnerRatio = ratio
            winner = column
    return winner

def majorityCnt(classList):
    """Return the class that occurs most often in ``classList``.

    Args:
        classList: non-empty sequence of hashable class values.

    Returns:
        The most frequent value.  When several values tie, the one that
        comes first in the counting dict's iteration order is returned.

    Raises:
        ValueError: if ``classList`` is empty.
    """
    classCount = {}
    for vote in classList:
        classCount[vote] = classCount.get(vote, 0) + 1
    # Rank by count; a plain max(classCount) would compare the keys
    # themselves and return the largest class name rather than the majority.
    return max(classCount, key=classCount.get)

def createTree(dataSet, labels):
    """Recursively build a decision tree as nested dicts.

    Args:
        dataSet: non-empty list of rows; the last column is the class.
        labels: names of the feature columns, in column order.  The
            sequence is not modified.

    Returns:
        Either a class value (leaf) or a dict of the form
        ``{featureLabel: {featureValue: subtree, ...}}``.
    """
    classList = [example[-1] for example in dataSet]
    # All samples share one class: nothing left to split.
    if classList.count(classList[0]) == len(classList):
        return classList[0]
    # Only the class column remains: fall back to a vote.
    if len(dataSet[0]) == 1:
        return majorityCnt(classList)

    bestFeat = chooseBestFeatureToSplit(dataSet)
    # Check the sentinel before indexing, so -1 is never used as a label index.
    if bestFeat == -1:
        # No feature separates the classes (the features do not explain
        # the class), so the first class seen is used as the result.
        return classList[0]

    bestFeatLabel = labels[bestFeat]
    # Build the child label list without deleting from the caller's
    # sequence, so ``labels`` can still be used after the tree is built.
    subLabels = labels[:bestFeat] + labels[bestFeat + 1:]

    myTree = {bestFeatLabel: {}}
    for value in set([example[bestFeat] for example in dataSet]):
        myTree[bestFeatLabel][value] = createTree(
            splitDataSet(dataSet, bestFeat, value), subLabels)
    return myTree

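# Build a small sample data set: weapon type (0 = rifle, 1 = machine gun), bullets (0 = few, 1 = many), blood (0 = low, 1 = high); the class is 'fight' or 'run'.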
def createDataSet():
    """Return a small hard-coded sample data set and its feature names.

    Each row is ``[weapon, bullet, blood, class]`` where the class is
    ``'fight'`` or ``'run'``.

    NOTE: this definition reuses the name of the file-based
    ``createDataSet(trainDataFile)`` defined earlier in the script and
    therefore replaces it at import time.

    Returns:
        A ``(dataSet, labels)`` tuple.
    """
    featureNames = ['weapon', 'bullet', 'blood']
    samples = [
        [1, 1, 0, 'fight'],
        [1, 0, 1, 'fight'],
        [1, 0, 1, 'fight'],
        [1, 0, 1, 'fight'],
        [0, 0, 1, 'run'],
        [0, 1, 0, 'fight'],
        [0, 1, 1, 'run'],
    ]
    return samples, featureNames

#按行打印数据集
def printData(myData):
    """Print every item of ``myData`` on its own line.

    Args:
        myData: iterable of rows (or any printable items).
    """
    for item in myData:
        # Wrap the item in a 1-tuple so that a row which is itself a tuple
        # is printed whole instead of being unpacked as format arguments.
        print('%s' % (item,))

#使用决策树分类
def classify(inputTree, featLabels, testVec):
    """Classify one sample by walking a nested-dict decision tree.

    Args:
        inputTree: tree of the form
            ``{featureLabel: {featureValue: subtree_or_class, ...}}``.
        featLabels: feature names, in the column order used by ``testVec``.
        testVec: feature values of the sample to classify.

    Returns:
        The class stored at the leaf that is reached, or ``None`` when the
        tree has no branch for one of the sample's feature values.

    Raises:
        ValueError: if a feature named in the tree is missing from
            ``featLabels``.
    """
    node = inputTree
    while isinstance(node, dict):
        # Each tree level is a single-key dict; next(iter(...)) reads that
        # key on both Python 2 and Python 3 (keys()[0] only works on 2).
        featName = next(iter(node))
        branches = node[featName]
        observed = testVec[featLabels.index(featName)]
        for branchValue in branches:
            if observed == branchValue:
                node = branches[branchValue]
                break
        else:
            # Value never seen during training: there is no branch to follow.
            return None
    return node

#存储决策树
def storeTree(inputTree, filename):
    """Serialize a decision tree to ``filename`` with pickle.

    Args:
        inputTree: the tree (any picklable object) to store.
        filename: path of the file to create or overwrite.
    """
    import pickle
    # Pickle data is bytes: binary mode is required on Python 3 and avoids
    # newline translation elsewhere.  ``with`` closes the file even if
    # dumping fails.
    with open(filename, 'wb') as fw:
        pickle.dump(inputTree, fw)

#获取决策树
def grabTree(filename):
    """Load a decision tree previously written with pickle.

    Args:
        filename: path of the pickle file to read.

    Returns:
        The unpickled object.
    """
    import pickle
    # NOTE(security): unpickling can execute arbitrary code; only load
    # files that come from a trusted source.
    # Binary mode is required on Python 3; ``with`` closes the handle,
    # which the previous version leaked.
    with open(filename, 'rb') as fr:
        return pickle.load(fr)

def main():
    """Build a decision tree from the sample data, print it and plot it."""
    samples, featureNames = createDataSet()
    tree = createTree(samples, featureNames)
    print(tree)

    # Plot the decision tree.  The plotting module is imported only at
    # this point, after the tree has been built and printed.
    # NOTE(review): showTree is presumably the matplotlib script that
    # accompanies this one, saved as showTree.py -- confirm.
    import showTree as show
    show.createPlot(tree)

# Run the demo only when this file is executed as a script.
if __name__ == '__main__':
    main()``````

``````#!/usr/bin/python
#coding:utf-8

import matplotlib.pyplot as plt

#决策树属性设置
# Box style for internal (decision) nodes.
decisionNode = {"boxstyle": "sawtooth", "fc": "0.8"}
# Box style for leaf nodes.
leafNode = {"boxstyle": "round4", "fc": "0.8"}
# Arrow style for the connection drawn between a node and its parent.
arrow_args = {"arrowstyle": "<-"}

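# createPlot is the entry point: call it to draw the decision tree. It calls all the remaining functions; the tree passed in must be a nested-dict decision tree.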
def createPlot(inThree):
    """Draw the given nested-dict decision tree and show the figure.

    The axes and the layout state are stored as function attributes
    (``createPlot.ax1``, ``plotTree.totalW``, ``plotTree.totalD``,
    ``plotTree.xOff``, ``plotTree.yOff``) that the other plotting
    functions read and update.

    Args:
        inThree: decision tree as nested dicts.
    """
    figure = plt.figure(1, facecolor='white')
    figure.clf()
    # Empty tick lists hide the axis ticks; omit them to see ticks while debugging.
    createPlot.ax1 = plt.subplot(111, frameon=False, xticks=[], yticks=[])

    # Total width (leaf count) and depth are used to scale node positions.
    plotTree.totalW = float(getNumLeafs(inThree))
    plotTree.totalD = float(getTreeDepth(inThree))
    # Start half a leaf slot left of the axes so the first leaf is centred in its slot.
    plotTree.xOff = -0.5 / plotTree.totalW
    plotTree.yOff = 1.0

    plotTree(inThree, (0.5, 1.0), '')
    plt.show()

#决策树上节点之间的箭头设置
def plotNode(nodeTxt, centerPt, parentPt, nodeType):
    """Draw one boxed node with an arrow linking it to its parent.

    Args:
        nodeTxt: text shown inside the box.
        centerPt: ``(x, y)`` of the box, in axes-fraction coordinates.
        parentPt: ``(x, y)`` of the parent end of the arrow, in
            axes-fraction coordinates.
        nodeType: box style dict (``decisionNode`` or ``leafNode``).
    """
    axes = createPlot.ax1
    axes.annotate(
        nodeTxt,
        xy=parentPt,
        xycoords='axes fraction',
        xytext=centerPt,
        textcoords='axes fraction',
        va="center",
        ha="center",
        bbox=nodeType,
        arrowprops=arrow_args,
    )

#决策树文字的添加位置和角度
def plotMidText(cntrPt, parentPt, txtString):
    """Write ``txtString`` halfway between a node and its parent.

    Args:
        cntrPt: ``(x, y)`` of the child node.
        parentPt: ``(x, y)`` of the parent node.
        txtString: text to place at the midpoint, rotated by 30 degrees.
    """
    childX, childY = cntrPt[0], cntrPt[1]
    midX = (parentPt[0] - childX) / 2.0 + childX
    midY = (parentPt[1] - childY) / 2.0 + childY
    createPlot.ax1.text(midX, midY, txtString,
                        va="center", ha="center", rotation=30)

#得到叶子节点的数量
def getNumLeafs(myTree):
    """Count the leaf nodes of a nested-dict decision tree.

    Args:
        myTree: tree of the form ``{featureLabel: {value: subtree_or_leaf}}``.

    Returns:
        The number of leaves (non-dict children) in the whole tree.
    """
    # next(iter(...)) reads the single root key on both Python 2 and 3;
    # keys()[0] fails on Python 3 because dict views are not indexable.
    firstStr = next(iter(myTree))
    numLeafs = 0
    for child in myTree[firstStr].values():
        # A dict child is a subtree; anything else is a leaf.
        if isinstance(child, dict):
            numLeafs += getNumLeafs(child)
        else:
            numLeafs += 1
    return numLeafs

#得到决策树的深度
def getTreeDepth(myTree):
    """Return the depth of a nested-dict decision tree.

    A tree whose children are all leaves has depth 1.

    Args:
        myTree: tree of the form ``{featureLabel: {value: subtree_or_leaf}}``.

    Returns:
        The number of decision levels on the longest root-to-leaf path
        (0 if the root has no children).
    """
    # next(iter(...)) reads the single root key on both Python 2 and 3;
    # keys()[0] fails on Python 3 because dict views are not indexable.
    firstStr = next(iter(myTree))
    maxDepth = 0
    for child in myTree[firstStr].values():
        if isinstance(child, dict):
            thisDepth = 1 + getTreeDepth(child)
        else:
            thisDepth = 1
        if thisDepth > maxDepth:
            maxDepth = thisDepth
    return maxDepth

#父子节点之间画决策树
def plotTree(myTree, parentPt, nodeTxt):
    """Recursively draw ``myTree`` below the point ``parentPt``.

    Layout state lives in function attributes set up by ``createPlot``:
    ``plotTree.totalW`` / ``plotTree.totalD`` (total leaf count and depth)
    and ``plotTree.xOff`` / ``plotTree.yOff`` (current drawing cursor).

    Args:
        myTree: (sub)tree of the form ``{featureLabel: {value: child}}``.
        parentPt: ``(x, y)`` of the parent node, in axes-fraction coordinates.
        nodeTxt: text for the edge from the parent to this subtree.
    """
    numLeafs = getNumLeafs(myTree)
    # next(iter(...)) reads the single root key on both Python 2 and 3.
    firstStr = next(iter(myTree))
    # Centre this decision node above the leaves it spans.
    cntrPt = (plotTree.xOff + (1.0 + float(numLeafs)) / 2.0 / plotTree.totalW,
              plotTree.yOff)
    plotMidText(cntrPt, parentPt, nodeTxt)
    plotNode(firstStr, cntrPt, parentPt, decisionNode)
    secondDict = myTree[firstStr]
    # Step one level down before drawing the children.
    plotTree.yOff = plotTree.yOff - 1.0 / plotTree.totalD
    for key in secondDict:
        child = secondDict[key]
        if isinstance(child, dict):
            plotTree(child, cntrPt, str(key))
        else:
            # Each leaf takes the next horizontal slot.
            plotTree.xOff = plotTree.xOff + 1.0 / plotTree.totalW
            plotNode(child, (plotTree.xOff, plotTree.yOff), cntrPt, leafNode)
            plotMidText((plotTree.xOff, plotTree.yOff), cntrPt, str(key))
    # Restore the vertical cursor for the caller's remaining siblings.
    plotTree.yOff = plotTree.yOff + 1.0 / plotTree.totalD
``````

© 著作权归作者所有

### 评论(2)

#### 引用来自“bjtbjt”的评论

import showTree as show
showTree是怎么来的呢

b
import showTree as show
showTree是怎么来的呢

AI这个词相信大家都非常熟悉，近几年来人工智能圈子格外热闹，光是AlphaGo就让大家对它刮目相看。今天小天就来跟大家唠一唠如何进军人工智能的第一步——机器学习。 在机器学习领域，Python已...

ufv59to8
2018/05/12
0
0
2018年某学院最新人工智能机器学习升级版视频教程

iyx668
2018/07/05
0
0
Kaggle实战之sklearn学习

silencehhh
2018/04/17
0
0
Python入门到机器学习再到深入学习及应用整个学习系统

m68futkmurmtj
2018/04/24
0
0

fengbingchun
2018/10/13
0
0

java知识分子
15分钟前
1
0

18分钟前
1
0

Jack088
21分钟前
2
0
windows 安装nvm

1、nvm-windows的官网：https://github.com/coreybutler/nvm-windows/releases 2、选择nvm-setup.zip安装 3、配置环境变量 4、检查nvm是否安装成功 使用管理员权限打开一个命令行。输入nvm v...

31分钟前
1
0
MySQL

33分钟前
1
0