文档章节

python 决策树 math库 c45算法

悲喜世界
 悲喜世界
发布于 2018/02/02 17:14
字数 1070
阅读 286
收藏 0

每周一搏,提升自我。

这段时间对python的应用,对python的理解越来越深。摸索中修改网上实例代码,有了自己的理解。

c45是ID3算法的升级版,比ID3高级。个人建议,用CART算法,感觉比C45好。

下面是c45代码,其中显示决策树结构的代码,下篇博文发布。

#!/usr/bin/python
#coding:utf-8

import operator
from math import log
import time
import os,sys
import string

#已文件为数据源
def createDataSet(trainDataFile):
	print trainDataFile
	dataSet=[]
	try:
		fin=open(trainDataFile)
		for line in fin:
			line=line.strip('\n')  #清除行皆为换行符
			cols=line.split(',')  #逗号分割行信息
			row =[cols[1],cols[2],cols[3],cols[4],cols[5],cols[6],cols[7],cols[8],cols[9],cols[10],cols[0]]
			dataSet.append(row)
			#print row
	except:
		print 'Usage xxx.py trainDataFilePath'
		sys.exit()
	labels=['cip1', 'cip2', 'cip3', 'cip4', 'sip1', 'sip2', 'sip3', 'sip4', 'sport', 'domain']
	print 'dataSetlen',len(dataSet)
	return dataSet,labels

#c4.5 信息熵算法
def calcShannonEntOfFeature(dataSet,feat):
	numEntries=len(dataSet)
	labelCounts={}
	for feaVec in dataSet:
		currentLabel=feaVec[feat]
		if currentLabel not in labelCounts:
			labelCounts[currentLabel]=0
		labelCounts[currentLabel]+=1
	shannonEnt=0.0
	for key in labelCounts:
		prob=float(labelCounts[key])/numEntries
		shannonEnt-=prob * log(prob,2)
	return shannonEnt

def splitDataSet(dataSet,axis,value):
	retDataSet=[]
	for featVec in dataSet:
		if featVec[axis] ==value:
			reducedFeatVec=featVec[:axis]
			reducedFeatVec.extend(featVec[axis+1:])
			retDataSet.append(reducedFeatVec)
	return retDataSet

def chooseBestFeatureToSplit(dataSet):
	numFeatures=len(dataSet[0])-1
	baseEntropy=calcShannonEntOfFeature(dataSet,-1)
	bestInfoGainRate=0.0
	bestFeature=-1
	for i in range(numFeatures):
		featList=[example[i] for example in dataSet]
		uniqueVals=set(featList)
		newEntropy=0.0
		for value in uniqueVals:
			subDataSet=splitDataSet(dataSet,i,value)
			prob=len(subDataSet) / float(len(dataSet))
			newEntropy+=prob * calcShannonEntOfFeature(subDataSet,-1)
		infoGain=baseEntropy- newEntropy
		iv = calcShannonEntOfFeature(dataSet,i)
		if(iv == 0):
			continue
		infoGainRate= infoGain /iv
		if infoGainRate > bestInfoGainRate:
			bestInfoGainRate = infoGainRate
			bestFeature = i
	return bestFeature

def majorityCnt(classList):
	classCount={}
	for vote in classList:
		if vote not in classCount.keys():
			classCount[vote]=0
		classCount[vote] +=1
	return max(classCount)


def createTree(dataSet,labels):
	classList= [example[-1] for example in dataSet]
	if classList.count(classList[0]) == len(classList):
		return classList[0]
	if len(dataSet[0]) == 1:
		return majorityCnt(classList)
	bestFeat = chooseBestFeatureToSplit(dataSet)
	bestFeatLabel = labels[bestFeat]
	if(bestFeat == -1): #特征一样,但类别不一样,即类别与特征不相关,随机选第一个类别分类结果
		return classList[0]
	myTree={bestFeatLabel:{}}
	del(labels[bestFeat])
	featValues =  [example[bestFeat] for example in dataSet]
	uniqueVals =set(featValues)
	for value in uniqueVals:
		subLabels = labels [:]
		myTree[bestFeatLabel][value]=createTree(splitDataSet(dataSet,bestFeat,value),subLabels)
	return myTree


#创建简单的数据集   武器类型(0 步枪 1机枪),子弹(0 少 1多),血量(0 少,1多)  fight战斗 1逃跑 
def createDataSet():
	dataSet =[[1,1,0,'fight'],[1,0,1,'fight'],[1,0,1,'fight'],[1,0,1,'fight'],[0,0,1,'run'],[0,1,0,'fight'],[0,1,1,'run']]
	lables=['weapon','bullet','blood']
	return dataSet,lables

#按行打印数据集
def printData(myData):
	for item in myData:
		print '%s' %(item)

#使用决策树分类
def classify(inputTree,featLabels,testVec):
	firstStr=inputTree.keys()[0]
	secondDict=inputTree[firstStr]
	featIndex=featLabels.index(firstStr)
	for key in secondDict.keys():
		if testVec[featIndex] ==key:
			if type(secondDict[key]).__name__=='dict':
				classLabel=classify(secondDict[key],featLabels,testVec)
			else:classLabel=secondDict[key]
	return classLabel

#存储决策树
def storeTree(inputTree,filename):
	import pickle
	fw=open(filename,'w')
	pickle.dump(inputTree,fw)
	fw.close()


#获取决策树
def grabTree(filename):
	import pickle
	fr=open(filename)
	return pickle.load(fr)



def main():
	data,label =createDataSet()
	myTree=createTree(data,label)
	print(myTree)

	#打印决策树
	import showTree as show
	show.createPlot(myTree)


if __name__ == '__main__':
	main()

调用的showTree.py,内容如下:

#!/usr/bin/python
#coding:utf-8

import matplotlib.pyplot as plt

#决策树属性设置
decisionNode=dict(boxstyle="sawtooth",fc="0.8")
leafNode=dict(boxstyle="round4",fc="0.8")
arrow_args=dict(arrowstyle="<-")


#createPlot 主函数,调用即可画出决策树,其中调用登了剩下的所有的函数,inTree的形式必须为嵌套的决策树
def createPlot(inThree):
	fig=plt.figure(1,facecolor='white')
	fig.clf()
	axprops=dict(xticks=[],yticks=[])
	createPlot.ax1=plt.subplot(111,frameon=False,**axprops)  #no ticks
	# createPlot.ax1=plt.subplot(111,frameon=False)  #ticks for demo puropses
	plotTree.totalW=float(getNumLeafs(inThree))
	plotTree.totalD=float(getTreeDepth(inThree))
	plotTree.xOff=-0.5/plotTree.totalW;
	plotTree.yOff=1.0
	plotTree(inThree,(0.5,1.0),'')
	plt.show()

#决策树上节点之间的箭头设置
def plotNode(nodeTxt,centerPt,parentPt,nodeType):
	createPlot.ax1.annotate(nodeTxt,xy=parentPt,xycoords='axes fraction',
		xytext=centerPt,textcoords='axes fraction',
		va="center",ha="center",bbox=nodeType,arrowprops=arrow_args)

#决策树文字的添加位置和角度
def plotMidText(cntrPt,parentPt,txtString):
	xMid=(parentPt[0] -cntrPt[0])/2.0 +cntrPt[0]
	yMid=(parentPt[1] -cntrPt[1])/2.0 +cntrPt[1]
	createPlot.ax1.text(xMid,yMid,txtString,va="center",ha="center",rotation=30)

#得到叶子节点的数量
def getNumLeafs(myTree):
	numLeafs=0
	firstStr=myTree.keys()[0]
	secondDict=myTree[firstStr]
	for key in secondDict.keys():
		if type(secondDict[key]).__name__=='dict':#test to see if the nodes are dictonaires, if not they are leaf nodes
			numLeafs += getNumLeafs(secondDict[key])
		else: numLeafs+=1
	return numLeafs

#得到决策树的深度
def getTreeDepth(myTree):
	maxDepthh=0
	firstStr=myTree.keys()[0]
	secondDict=myTree[firstStr]
	for key in secondDict.keys():
		if type(secondDict[key]).__name__=='dict':
			thisDepth=1+getTreeDepth(secondDict[key])
		else: thisDepth=1
		if thisDepth>maxDepthh:maxDepthh=thisDepth
	return maxDepthh

#父子节点之间画决策树
def plotTree(myTree,parentPt,nodeTxt):
	numLeafs=getNumLeafs(myTree)
	depth=getTreeDepth(myTree)
	firstStr=myTree.keys()[0]
	cntrPt=(plotTree.xOff +(1.0+float(numLeafs))/2.0/plotTree.totalW,plotTree.yOff)
	plotMidText(cntrPt,parentPt,nodeTxt)
	plotNode(firstStr,cntrPt,parentPt,decisionNode)
	secondDict=myTree[firstStr]
	plotTree.yOff = plotTree.yOff - 1.0/plotTree.totalD
	for key in secondDict.keys():
		if type(secondDict[key]).__name__=='dict':
			plotTree(secondDict[key],cntrPt,str(key))
		else:
			plotTree.xOff=plotTree.xOff+1.0/plotTree.totalW
			plotNode(secondDict[key],(plotTree.xOff,plotTree.yOff),cntrPt,leafNode)
			plotMidText((plotTree.xOff,plotTree.yOff),cntrPt,str(key))
	plotTree.yOff=plotTree.yOff+1.0/plotTree.totalD

 

© 著作权归作者所有

共有 人打赏支持
悲喜世界
粉丝 3
博文 25
码字总数 14973
作品 0
海淀
程序员
私信 提问
加载中

评论(2)

悲喜世界
悲喜世界

引用来自“bjtbjt”的评论

import showTree as show
showTree是怎么来的呢
已追加
b
bjtbjt
import showTree as show
showTree是怎么来的呢
机器学习的最佳学习路线原来只有四步

AI这个词相信大家都非常熟悉,近几年来人工智能圈子格外热闹,光是AlphoGo就让大家对它刮目相看。今天小天就来跟大家唠一唠如何进军人工智能的第一步——机器学习。 在机器学习领域,Python已...

ufv59to8
2018/05/12
0
0
2018年某学院最新人工智能机器学习升级版视频教程

百度云盘下载 ==========课程目录============== └─视频 01 数学分析与概率论.mp4 02 数理统计与参数估计.avi 03 矩阵和线性代数.avi 04 凸优化.avi 05 Python库.avi 06 Python库II.mp4 07...

iyx668
2018/07/05
0
0
Kaggle实战之sklearn学习

今天刚刚接触python机器学习之kaggle实战这本书,初步学习了python机器学习库之sklearn的基本运用,照葫芦画瓢的对书中代码进行了一定的编写运行,小小记录我学机器学习之路 主要是这对支持向...

silencehhh
2018/04/17
0
0
Python入门到机器学习再到深入学习及应用整个学习系统

就在昨天我们收到了一位刚拿到Google offer的九章学员发来的截图 作为一名同是转专业到cs的程序猿,对此猿我定要表示万分真心的理解和祝贺! 其中滋味,唯吾猿类方懂… 此外这位细心的猿还找...

m68futkmurmtj
2018/04/24
0
0
决策树的C++实现(CART)

版权声明:本文为博主原创文章,未经博主允许不得转载。 https://blog.csdn.net/fengbingchun/article/details/83042636 关于决策树的介绍可以参考: https://blog.csdn.net/fengbingchun/a...

fengbingchun
2018/10/13
0
0

没有更多内容

加载失败,请刷新页面

加载更多

刚入职阿里,告诉你真实的职场生活,兼谈P6、P7、P8的等级

一:拿下offer的人,基本上都有什么特征? 二:为什么选择阿里? 三:阿里的工作氛围什么样? 四:阿里的薪资情况? 五:阿里的晋升空间有多大? 最近部门招聘,很多工程师,包括我在内都参与...

java知识分子
15分钟前
1
0

中国龙-扬科
18分钟前
1
0
深入理解定时器系列第一篇——理解setTimeout和setInterval

很长时间以来,定时器一直是javascript动画的核心技术。但是,关于定时器,人们通常只了解如何使用setTimeout()和setInterval(),对它们的内在运行机制并不理解,对于与预想不同的实际运行状...

Jack088
21分钟前
2
0
windows 安装nvm

1、nvw-windows的官网:https://github.com/coreybutler/nvm-windows/releases 2、选择nvm-setup.zip安装 3、配置环境变量 4、检查nvm是否安装成功 使用管理员权限打开一个命令行。输入nvm v...

灰白发
31分钟前
1
0
MySQL

慢日志查询作用 慢日志查询的主要功能就是,记录sql语句中超过设定的时间阈值的查询语句。例如,一条查询sql语句,我们设置的阈值为1s,当这条查询语句的执行时间超过了1s,则将被写入到慢查...

士兵7
33分钟前
1
0

没有更多内容

加载失败,请刷新页面

加载更多

返回顶部
顶部