Decision Tree Classifier

airxiechao · published 2017/03/20 08:34

A decision tree classifier recursively generates splitting rules that minimize the impurity of each resulting subset, until every sample in a subset belongs to the same class.
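
The impurity measure used below is Gini impurity: for a subset S with class proportions p_k, Gini(S) = 1 - sum_k p_k^2, which is 0 when all samples in S share one class. A minimal standalone sketch (the toy label arrays are only for illustration):

import numpy as np

def gini_impurity(labels):
    # Gini(S) = 1 - sum_k p_k^2, where p_k is the share of class k in S
    _, counts = np.unique(labels, return_counts=True)
    p = counts / labels.shape[0]
    return 1 - np.sum(p ** 2)

print(gini_impurity(np.array([0, 0, 1, 1])))  # 0.5, maximally mixed two-class set
print(gini_impurity(np.array([1, 1, 1, 1])))  # 0.0, pure subset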

import numpy as np
from sklearn.datasets import load_iris

# for rendering the scikit-learn tree at the end
from IPython.display import Image
from sklearn import tree
import pydotplus

data = load_iris()
X = data.data
y = data.target

class MyDecisionTreeClassifier:
    
    def gini(self, X, y, idx_data):
        # Gini impurity of the subset indexed by idx_data
        if idx_data.shape[0] == 0:
            return 0

        y_sub = y[idx_data]
        p = 1
        for v in np.unique(y_sub):
            p -= (y_sub[y_sub == v].shape[0] / idx_data.shape[0]) ** 2
        return p
        
    def split_data(self, X, y, idx_data, idx_feat, val_feat):
        # partition idx_data by thresholding feature idx_feat at val_feat
        feat = X[idx_data][:, idx_feat]
        idx_left = idx_data[np.flatnonzero(feat < val_feat)]
        idx_right = idx_data[np.flatnonzero(feat >= val_feat)]
        return idx_left, idx_right
    
    def best_split_data(self, X, y, idx_data):
        # evaluate every (feature, threshold) candidate and keep the pair that
        # minimizes the size-weighted Gini impurity of the two subsets:
        #   (|left| * gini(left) + |right| * gini(right)) / |data|
        impurities = {}
        for f in range(X.shape[1]):
            for v in np.unique(X[:, f]):
                idx_left, idx_right = self.split_data(X, y, idx_data, f, v)
                gini_left = self.gini(X, y, idx_left)
                gini_right = self.gini(X, y, idx_right)
                impurities[(f, v)] = (idx_left.shape[0] * gini_left +
                                      idx_right.shape[0] * gini_right) / idx_data.shape[0]

        idx_feat, val_feat = min(impurities, key=impurities.get)
        return idx_feat, val_feat
    
    def build_tree(self, X, y, idx_data):
        if idx_data.shape[0] == 0:
            return None

        node_tree = {
            'idx_feat': None,   # feature index used for the split
            'val_feat': None,   # threshold value for the split
            'node_left': None,  # subtree with feature < threshold
            'node_right': None, # subtree with feature >= threshold
            'target': None      # class label if this node is a leaf
        }

        # stop when the subset is pure: make a leaf
        if np.unique(y[idx_data]).shape[0] == 1:
            node_tree['target'] = np.unique(y[idx_data])[0]
            return node_tree

        idx_feat, val_feat = self.best_split_data(X, y, idx_data)
        idx_left, idx_right = self.split_data(X, y, idx_data, idx_feat, val_feat)

        # if no candidate split separates the subset (e.g. duplicate rows with
        # mixed labels), fall back to a majority-vote leaf instead of recursing
        # forever on the same subset
        if idx_left.shape[0] == 0 or idx_right.shape[0] == 0:
            vals, counts = np.unique(y[idx_data], return_counts=True)
            node_tree['target'] = vals[np.argmax(counts)]
            return node_tree

        node_tree['idx_feat'] = idx_feat
        node_tree['val_feat'] = val_feat
        node_tree['node_left'] = self.build_tree(X, y, idx_left)
        node_tree['node_right'] = self.build_tree(X, y, idx_right)
        return node_tree
        
    def fit(self, X, y):
        # grow the tree over all sample indices
        self.node_tree = self.build_tree(X, y, np.arange(X.shape[0]))
    
    def predict_single(self, node_tree, x):
        # walk down the tree until a leaf (a node with a target) is reached
        target = node_tree['target']
        if target is not None:
            return target

        idx_feat = node_tree['idx_feat']
        val_feat = node_tree['val_feat']

        if x[idx_feat] < val_feat:
            return self.predict_single(node_tree['node_left'], x)
        else:
            return self.predict_single(node_tree['node_right'], x)

    def predict(self, X):
        # map returns a lazy iterator in Python 3, so build a list explicitly
        return np.array([self.predict_single(self.node_tree, x) for x in X])

    def score(self, X, y):
        # fraction of correctly classified samples
        return np.count_nonzero(self.predict(X) == y) / y.shape[0]
    
    def plot_tree_level(self, node_tree, level):
        # print the tree as indented text; relies on the global `data`
        # for the iris feature and class names
        idx_feat = node_tree['idx_feat']
        val_feat = node_tree['val_feat']
        target = node_tree['target']

        if level == 0:
            indent = '|--'
        else:
            indent = '      ' * level + '  |--'

        # internal nodes show the split rule, leaves show the class name
        if idx_feat is not None:
            print(indent, data.feature_names[idx_feat], 'by', val_feat)
        else:
            print(indent, '[', data.target_names[target], ']')
            return

        self.plot_tree_level(node_tree['node_left'], level + 1)
        self.plot_tree_level(node_tree['node_right'], level + 1)
        
    def plot_tree(self):
        self.plot_tree_level(self.node_tree, 0)
        
dt = MyDecisionTreeClassifier()
dt.fit(X, y)
print('score:', dt.score(X, y))
dt.plot_tree()

#score: 1.0
#|-- petal width (cm) by 1.0
#        |-- [ setosa ]
#        |-- petal width (cm) by 1.8
#              |-- petal length (cm) by 5.0
#                    |-- petal width (cm) by 1.7
#                          |-- [ versicolor ]
#                          |-- [ virginica ]
#                    |-- petal width (cm) by 1.6
#                          |-- [ virginica ]
#                          |-- sepal length (cm) by 6.8
#                                |-- [ versicolor ]
#                                |-- [ virginica ]
#              |-- petal length (cm) by 4.9
#                    |-- sepal width (cm) by 3.1
#                          |-- [ virginica ]
#                          |-- [ versicolor ]
#                    |-- [ virginica ]
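
Note that the score of 1.0 above is measured on the training data itself; a tree grown until every leaf is pure always reproduces its training labels perfectly. As a quick sanity check, one could evaluate on a held-out split (using scikit-learn's train_test_split, which is not part of the original listing):

from sklearn.model_selection import train_test_split

# hypothetical hold-out evaluation of the hand-rolled classifier
X_tr, X_te, y_tr, y_te = train_test_split(X, y, test_size=0.3, random_state=0)
dt_holdout = MyDecisionTreeClassifier()
dt_holdout.fit(X_tr, y_tr)
print('held-out score:', dt_holdout.score(X_te, y_te))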

# compare with scikit-learn's implementation
clf = tree.DecisionTreeClassifier()
clf.fit(X, y)

print('scikit score:', np.count_nonzero(clf.predict(X) == y) / y.shape[0])

# render the fitted tree with graphviz
dot_data = tree.export_graphviz(clf, out_file=None,
    feature_names=data.feature_names,
    class_names=data.target_names,
    filled=True, rounded=True,
    special_characters=True)
graph = pydotplus.graph_from_dot_data(dot_data)
Image(graph.create_png())
#scikit score: 1.0
