title: 机器学习:决策树(decision tree)
date: 2019-11-16 15:23:53
mathjax: true
categories:

  • 机器学习
    tags:
  • 机器学习

什么是决策树?

  • 示例:

    1573889558948

    流程图就是一个决策树,长方形代表判断模块 (decision block),椭圆形代表终止模块(terminating block),表示已经得出结论,可以终止运行。 从判断模块引出的左右箭头称作分支(branch),它可以到达另一个判断模块或者终止模块。图中构造了一个假想的邮件分类系统,它首先检测发送邮件域名地址。如果地址为 myEmployer.com,则将其放在分类“无聊时需要阅读的邮件”中。如果邮件不是来自这个域名, 则检查邮件内容里是否包含单词曲棍球,如果包含则将邮件归类到“需要及时处理的朋友邮件”, 如果不包含则将邮件归类到“无需阅读的垃圾邮件”

  • 理解的决策树:简单理解就是if elif else 的语句,判断判断再判断,直到能得到一个比较满意的label

构建决策树

  • 决策树:

    优点:计算复杂度不高,输出结果易于理解,对中间值的缺失不敏感,可以处理不相关特

    征数据。

    缺点:可能会产生过度匹配问题

    适用数据类型:数值型和标称型

  • 创建分支的伪代码

    检测数据集中的每个子项是否属于同一分类: 
    If so return 类标签;
    Else 
     寻找划分数据集的最好特征
     划分数据集
     创建分支节点
    for 每个划分的子集
     调用函数createBranch并增加返回结果到分支节点中
    return 分支节点
    
  • 流程:

    收集数据:可以使用任何方法。
    准备数据:树构造算法 (这里使用的是ID3算法,只适用于标称型数据,这就是为什么数值型数据必须离散化。 还有其他的树构造算法,比如CART)
    分析数据:可以使用任何方法,构造树完成之后,我们应该检查图形是否符合预期。
    训练算法:构造树的数据结构。
    测试算法:使用训练好的树计算错误率。
    使用算法:此步骤可以适用于任何监督学习任务,而使用决策树可以更好地理解数据的内在含义。
    
  • 举例模板:

    1573890211011

熵及其有关概念

  • 熵(entropy):指的是群体的混乱程度,我们的决策树的要求是在特征下将条件熵降到最低

    H = − ∑ i = 1 n p ( x i ) l o g 2 p ( x i ) H = - \sum^{n}_{i=1}p(x_i)log_2p(x_i) H=i=1np(xi)log2p(xi)

    其中p(xi)是选择该分类的概率

  • 划分数据集的大原则是:将无序的数据变得更加有序。我们可以使用多种方法划分数据集,

    但是每种方法都有各自的优缺点。组织杂乱无章数据的一种方法就是使用信息论度量信息,信息论

    是量化处理信息的分支科学

  • 信息增益:在划分数据集之前之后信息发生的变化称为信息增益,获得信息增益最高的特征就是最好的选择

    G a i n ( D , A ) = H ( D ) − H ( D ∣ A ) Gain(D,A) = H(D)-H(D|A) Gain(D,A)=H(D)H(DA)

  • 条件熵:这里我只给一个公式,主要是一个大哥讲的更好:网址

  • 香农熵(信息熵的计算):

    from math import log
    def calcShannonEnt(dataSet):
        # 求list的长度,表示计算参与训练的数据量
        numEntries = len(dataSet)
        # 计算分类标签label出现的次数
        labelCounts = {}
        # the the number of unique elements and their occurrence
        for featVec in dataSet:
            # 将当前实例的标签存储,即每一行数据的最后一个数据代表的是标签
            currentLabel = featVec[-1]
            # 为所有可能的分类创建字典,如果当前的键值不存在,则扩展字典并将当前键值加入字典。每个键值都记录了当前类别出现的次数。
            if currentLabel not in labelCounts.keys():
                labelCounts[currentLabel] = 0
            labelCounts[currentLabel] += 1
    
        # 对于 label 标签的占比,求出 label 标签的香农熵
        shannonEnt = 0.0
        for key in labelCounts:
            # 使用所有类标签的发生频率计算类别出现的概率。
            prob = float(labelCounts[key])/numEntries
            # 计算香农熵,以 2 为底求对数
            shannonEnt -= prob * log(prob, 2)
        return shannonEnt
    
  • 基尼不纯度:

    讲解案例:

    一个随机事件X ,P(X=0)= 0.5 ,P(X=1)=0.5

      那么基尼不纯度就为   P(X=0)*(1 - P(X=0)) +  P(X=1)*(1 - P(X=1))  = 0.5
    
    
    
    
    
    
       一个随机事件Y ,P(Y=0)= 0.1 ,P(Y=1)=0.9
    
      那么基尼不纯度就为P(Y=0)*(1 - P(Y=0)) +   P(Y=1)*(1 -P(Y=1))  = 0.18
    
     很明显 X比Y更混乱,因为两个都为0.5 很难判断哪个发生。而Y就确定得多,Y=1发生的概率很大。而基尼不纯度也就越小。
    

    结论:

    (1)基尼不纯度可以作为 衡量系统混乱程度的 标准;

    (2)基尼不纯度越小,纯度越高,集合的有序程度越高,分类的效果越好;

    (3)基尼不纯度为 0 时,表示集合类别一致;

    (4)在决策树中,比较基尼不纯度的大小可以选择更好的决策条件(子节点)

    # 示例代码:
    my_data = [['fan', 'C', 'yes', 32, 'None'],
               ['fang', 'U', 'yes', 23, 'Premium'],
               ['ming', 'F', 'no', 28, 'Basic']]
    
    
    # 计算每一行数据的可能数量
    def uniqueCounts(rows):
        results = {}
        for row in rows:
            # 对最后一列的值计算
            # r = row[len(row) - 1]
            # 对倒数第三的值计算,也就是yes 和no 的一列
            r = row[len(row) - 3]
            if r not in results: results[r] = 0
            results[r] += 1
        return results
    
    
    # 基尼不纯度样例
    def giniImpurityExample(rows):
        total = len(rows)
        print(total)
        counts = uniqueCounts(rows)
        print(counts)
        imp = 0
        for k1 in counts:
            p1 = float(counts[k1]) / total
            print(counts[k1])
            for k2 in counts:
                if k1 == k2: continue
                p2 = float(counts[k2]) / total
                imp += p1 * p2
        return imp
    
    
    gini = giniImpurityExample(my_data)
    print('gini Impurity is %s' % gini)
    
  • 总结:

构建决策树

  • 数据集:

    def createDataSet():
        dataSet = [[1, 1, 'yes'],
                   [1, 1, 'yes'],
                   [1, 0, 'no'],
                   [0, 1, 'no'],
                   [0, 1, 'no']]
        labels = ['no surfacing', 'flippers']
        return dataSet, labels
    myDat,labels = createDataSet()
    

    这里我们使用特征数值化,label标称化,true是1,false是0,label是yes 或者no

    结果:

    1573890839726

    运行熵计算函数

    1573890948494

按照给定特征划分数据集

  • 代码:

    def splitDataSet(dataSet, index, value):
        """splitDataSet(通过遍历dataSet数据集,求出index对应的colnum列的值为value的行)
            就是依据index列进行分类,如果index列的数据等于 value的时候,就要将 index 划分到我们创建的新的数据集中
        Args:
            dataSet 数据集                 待划分的数据集
            index 表示每一行的index列        划分数据集的特征
            value 表示index列对应的value值   需要返回的特征的值。
        Returns:
            index列为value的数据集【该数据集需要排除index列】
        """
        retDataSet = []
        for featVec in dataSet: 
            # index列为value的数据集【该数据集需要排除index列】
            # 判断index列的值是否为value
            if featVec[index] == value:
                # chop out index used for splitting
                # [:index]表示前index行,即若 index 为2,就是取 featVec 的前 index 行
                reducedFeatVec = featVec[:index]
                '''
                请百度查询一下: extend和append的区别
                music_media.append(object) 向列表中添加一个对象object
                music_media.extend(sequence) 把一个序列seq的内容添加到列表中 (跟 += 在list运用类似, music_media += sequence)
                1、使用append的时候,是将object看作一个对象,整体打包添加到music_media对象中。
                2、使用extend的时候,是将sequence看作一个序列,将这个序列和music_media序列合并,并放在其后面。
                music_media = []
                music_media.extend([1,2,3])
                print music_media
                #结果:
                #[1, 2, 3]
                
                music_media.append([4,5,6])
                print music_media
                #结果:
                #[1, 2, 3, [4, 5, 6]]
                
                music_media.extend([7,8,9])
                print music_media
                #结果:
                #[1, 2, 3, [4, 5, 6], 7, 8, 9]
                '''
                reducedFeatVec.extend(featVec[index+1:])
                # [index+1:]表示从跳过 index 的 index+1行,取接下来的数据
                # 收集结果值 index列为value的行【该行需要排除index列】
                retDataSet.append(reducedFeatVec)
        return retDataSet
    
  • 运行示例:

    1573891166656

    其实上面划分数据集的,就是将数据中指定的索引等于指定值的数据提取出来

选择最好的数据集划分方式

  • 划分代码:

    #选择最好的数据集划分方式
    def chooseBestFeatureToSplit(dataSet):
        """chooseBestFeatureToSplit(选择最好的特征)
    
        Args:
            dataSet 数据集
        Returns:
            bestFeature 最优的特征列
        """
        # 求第一行有多少列的 Feature, 最后一列是label列嘛
        numFeatures = len(dataSet[0]) - 1
        # 数据集的原始信息熵
        baseEntropy = calcShannonEnt(dataSet)
        # 最优的信息增益值, 和最优的Featurn编号
        bestInfoGain, bestFeature = 0.0, -1
        # iterate over all the features
        for i in range(numFeatures):
            # create a list of all the examples of this feature
            # 获取对应的feature下的所有数据
            featList = [example[i] for example in dataSet]
            # get a set of unique values
            # 获取剔重后的集合,使用set对list数据进行去重
            uniqueVals = set(featList)
            # 创建一个临时的信息熵
            newEntropy = 0.0
            # 遍历某一列的value集合,计算该列的信息熵 
            # 遍历当前特征中的所有唯一属性值,对每个唯一属性值划分一次数据集,计算数据集的新熵值,并对所有唯一特征值得到的熵求和。
            for value in uniqueVals:
                subDataSet = splitDataSet(dataSet, i, value)
                # 计算概率
                prob = len(subDataSet)/float(len(dataSet))
                # 计算条件熵
                newEntropy += prob * calcShannonEnt(subDataSet)
            # gain[信息增益]: 划分数据集前后的信息变化, 获取信息熵最大的值
            # 信息增益是熵的减少或者是数据无序度的减少。最后,比较所有特征中的信息增益,返回最好特征划分的索引值。
            infoGain = baseEntropy - newEntropy
            print('infoGain=', infoGain, 'bestFeature=', i, baseEntropy, newEntropy)
            if (infoGain > bestInfoGain):
                bestInfoGain = infoGain
                bestFeature = i
        return bestFeature
    

    这个代码中的条件熵的计算我自己认为是一个期望熵,就是按照这个特征分下去,总体的期望

  • 运行示例:

image

解释:其实就是按照第0个特征来分,总体的信息增益最大,也就是群体有序程度更高

递归创建决策树

  • 投票机制:vote,这是没有剪枝情况下,最坏情况下使用的投票机制,也就是没有找到完全划分的方式

    def majorityCnt(classList):
        classCount = {}
        for vote in classList:
            if vote not in classCount.keys():
                classCount[vote] = 0
            else:
                classCount[vote] += 1
            sortedClassCount = sorted(classCount.item(),key=lambda x:x[1],reverse = True)
            return sortedClassCount[0][0]
    
  • 创建树的代码

    def createTree(dataSet, labels):
        classList = [example[-1] for example in dataSet]
        # 如果数据集的最后一列的第一个值出现的次数=整个集合的数量,也就说只有一个类别,就只直接返回结果就行
        # 第一个停止条件:所有的类标签完全相同,则直接返回该类标签。
        # count() 函数是统计括号中的值在list中出现的次数
        if classList.count(classList[0]) == len(classList):
            return classList[0]
        # 如果数据集只有1列,那么最初出现label次数最多的一类,作为结果
        # 第二个停止条件:使用完了所有特征,仍然不能将数据集划分成仅包含唯一类别的分组。
        if len(dataSet[0]) == 1:
            return majorityCnt(classList)
    
        # 选择最优的列,得到最优列对应的label含义
        bestFeat = chooseBestFeatureToSplit(dataSet)
        # 获取label的名称
        bestFeatLabel = labels[bestFeat]
        # 初始化myTree
        myTree = {bestFeatLabel: {}}
        # 注:labels列表是可变对象,在PYTHON函数中作为参数时传址引用,能够被全局修改
        # 所以这行代码导致函数外的同名变量被删除了元素,造成例句无法执行,提示'no surfacing' is not in list
        del(labels[bestFeat])
        # 取出最优列,然后它的branch做分类
        featValues = [example[bestFeat] for example in dataSet]
        uniqueVals = set(featValues)
        for value in uniqueVals:
            # 求出剩余的标签label
            subLabels = labels[:]
            # 遍历当前选择特征包含的所有属性值,在每个数据集划分上递归调用函数createTree()
            myTree[bestFeatLabel][value] = createTree(splitDataSet(dataSet, bestFeat, value), subLabels)
            # print 'myTree', value, myTree
        return myTree
    

    这里比较难理解,自己debug一下

  • 运行示例:

    1573892031215

  • 决策树测试代码:

    def classify(inputTree, featLabels, testVec):
        '''
        函数功能:
                对测试实例进行分类
        参数说明:
                inputTree__已经训练好的决策树
                featLabels__特征标签类别
                testVec__测试示例
        函数返回:
                分类结果  
        '''
        # python3.x中input.key()[0]返回的是dict_keys,不是list,这里注意区别(书上的代码是python2.x)
        firstStr = list(inputTree.keys())[0]  # 获得决策树第一个节点
        #print(featLabels)
        secondDict = inputTree[firstStr]      # 获取下一个字典
        print(secondDict)
        print(firstStr)
        featIndex = featLabels.index(firstStr)    # 将标签字符串转换为索引(第一个节点所在列的索引)
        for key in secondDict.keys():
            #print(testVec[featIndex])
            if testVec[featIndex] == key:
                if type(secondDict[key]).__name__ == 'dict':
                    classLabel = classify(secondDict[key], featLabels, testVec)
                else:
                    classLabel = secondDict[key]
        return classLabel
    
  • 运行示例

    1573892127137

matplotlib注解绘制树形图

  • matplotlib绘制图像示例

    #使用matploblib注解绘制树形图
    #使用matploblib注解绘制树形图
    import matplotlib.pyplot as plt
    
    from pylab import *
    mpl.rcParams['font.sans-serif'] = ['SimHei']#中文注释
    decisionNode = dict(boxstyle='sawtooth',fc="0.8")
    leafNode = dict(boxstyle="round4",fc="0.8")
    arrow_args = dict(arrowstyle="<-")
    def plotNode(nodeTxt,centerPt,parentPt,nodeType):
        createPlot.axl.annotate(nodeTxt,xy=parentPt,xycoords="axes fraction",xytext=centerPt,textcoords="axes fraction"
                               ,va="center",ha="center",bbox=nodeType,arrowprops=arrow_args)# https://blog.csdn.net/leaf_zizi/article/details/82886755
    def createPlot():
        fig = plt.figure(1,facecolor="white")
        fig.clf
        createPlot.axl = plt.subplot(111,facecolor="white")
        plotNode(U"决策节点",(0.5,0.1),(0.1,0.5),decisionNode)
        plotNode(U"叶子节点",(0.8,0.1),(0.3,0.8),decisionNode)
        plt.show()
    #https://blog.csdn.net/u013038499/article/details/52449768
    

    444

决策树属性的描述

  • 叶子节点数目

    #获得叶子节点数目
    def getNumLeafs(myTree):
        '''
        函数功能:
                递归计算叶子节点数目
        函数参数:
                字典形式的决策树
        函数返回:
                叶子节点数目
        '''
        numLeafs = 0                                           # 初始化叶子节点的数目
        print(list(myTree.keys()))
        firstStr = list(myTree.keys())[0]                      # 获取决策树的第一个节点
        secondDict = myTree[firstStr]                          # 获取决策树的第二个节点  
        for key in secondDict.keys():                          
            if type(secondDict[key]).__name__ == 'dict':       # 若该节点为字典形式 
                numLeafs += getNumLeafs(secondDict[key])       # 若为字典,则递归计算新分支叶节点数 
            else:
                numLeafs += 1                                  # 若不是字典,则此节点为叶子节点 
        return numLeafs                                        # 返回叶子节点数目
    
    # 函数测试
    # labels = ['no surfacing', 'flippers', 'labels']
    # labels_copy2 = labels[:]
    # print(myDat)
    # print(labels)
    # tree = createTree(myDat,labels_copy2)
    
  • 树深度

    #得到树的深度
    def getTreeDepth(myTree):
        '''
        函数功能:
                递归计算决策树的深度
        函数参数:
                myTree__字典形式的决策树
        函数返回:
                决策树的最大深度
        '''
        maxDepth = 0                                            
        firstStr = list(myTree.keys())[0]
        secondDict = myTree[firstStr]
        for key in secondDict.keys():
            if type(secondDict[key]).__name__ == 'dict':
                thisDepth = 1 + getTreeDepth(secondDict[key])
            else:
                thisDepth = 1
            if thisDepth > maxDepth:
                maxDepth = thisDepth
        return maxDepth
    #getTreeDepth(mytree)
    

树的标注

  • 标注使用详细

  • 文本标注

    #使用文本标注
    decisionNode = dict(boxstyle = 'sawtooth', fc = '0.8')    # 设置中间节点的格式 
    leafNode = dict(boxstyle = 'round4', fc = '0.8')          # 设置叶子节点的格式
    arrow_args = dict(arrowstyle = '<-')                      # 定义箭头格式
    def plotNode(nodeTxt, centerPt, parentPt, nodeType):
        '''
        函数功能:
                绘制节点
        参数说明:
                nodeTxt__节点名
                centerPt__文本位置
                parentPt__标注的箭头位置
                nodeType__节点格式
        '''
        createPlot.ax1.annotate(nodeTxt,                                          # 文本内容  
                               xy = parentPt, xycoords = 'axes fraction',         # 注释的起始位置,坐标系
                               xytext = centerPt, textcoords = 'axes fraction',   # 文本的起始位置
                               va = 'center', ha = 'center',                      # 水平对齐,垂直对齐
                               bbox = nodeType,                                   # 节点格式
                               arrowprops = arrow_args)                           # 箭头格式 
    
  • 边的标注

    #标注有向边
    def plotMidText(cntrPt, parentPt, txtString):
        '''
        函数功能:
                标注有向边内容
        参数说明:
                cntrpt、parentPt__计算标注位置
                txtString__标注内容
        '''
        xMid = (parentPt[0] - cntrPt[0])/ 2.0 + cntrPt[0]        # 计算文本位置的横坐标
        yMid = (parentPt[1] - cntrPt[1])/2.0 + cntrPt[1]         # 计算文本位置的纵坐标
        createPlot.ax1.text(xMid,                                # 文本位置的横坐标 
                            yMid,                                # 文本位置的纵坐标
                            txtString)                           # 标注内容
    
    
  • 绘制树

    def plotTree(myTree, parentPt, nodeTxt):
        '''
        函数功能:
                绘制决策树
        函数参数:
                myTree__决策树
                parentPt__标注的内容
                nodeTxt__节点名称
        '''
        numLeafs = getNumLeafs(myTree)                          # 获取决策树叶结点数目,决定了树的宽度 
        depth = getTreeDepth(myTree)                            # 获取决策树层数,决定了树的高度
        firstStr = next(iter(myTree))                           # 获得决策树第一个节点
        cntrPt = (plotTree.xOff+(1.0+float(numLeafs))/2.0/plotTree.totalW,plotTree.yOff)#确定中心位置
        plotMidText(cntrPt, parentPt, nodeTxt)                  # 标注有向边内容
        plotNode(firstStr, cntrPt, parentPt, decisionNode)      # 绘制节点
        secondDict = myTree[firstStr]                           # 获取下一个字典
        plotTree.yOff = plotTree.yOff - 1.0/plotTree.totalD     # y偏移值
        for key in secondDict.keys():
            if type(secondDict[key]).__name__ == 'dict':        # 该结点是否为字典
                plotTree(secondDict[key], cntrPt, str(key))     # 如果是字典则不是叶结点,递归调用继续绘制
            else:                                
                plotTree.xOff = plotTree.xOff + 1.0/plotTree.totalW                          # x偏移值
                plotNode(secondDict[key], (plotTree.xOff, plotTree.yOff), cntrPt, leafNode)  # 绘制节点
                plotMidText((plotTree.xOff, plotTree.yOff), cntrPt, str(key))                # 标注有向边内容
        plotTree.yOff = plotTree.yOff + 1.0/plotTree.totalD          
    
    
  • 画图:

    def createPlot(inTree):
        '''
        函数功能:绘制完整的决策树
        参数说明:
                inTree__决策树
        '''
        fig = plt.figure(1, facecolor='white')                      #创建画布
        fig.clf()                                                   #清空画布
        axprops = dict(xticks=[], yticks=[])   
        createPlot.ax1 = plt.subplot(111, frameon=False, **axprops) # 除去x、y轴
        plotTree.totalW = float(getNumLeafs(inTree))                # 获取决策树叶结点数目
        plotTree.totalD = float(getTreeDepth(inTree))               # 获取决策树深度
        plotTree.xOff = -0.5/plotTree.totalW                        # x偏移的初始值
        plotTree.yOff = 1.0                                         # y偏移的初始值
        plotTree(inTree, (0.5,1.0), '')                             # 绘制决策树
        plt.show()                                                  #显示图像
    
    # 测试函数
    # labels = ['no surfacing', 'flippers', 'labels']
    # tree = createTree(dataSet, labels)
    createPlot(tree)
    
    

    结果:

保存树

def storeTree(inputTree, filename):
    '''
    函数功能:
            将决策树保存在磁盘中
    函数参数:
            inputTree__决策树
            filename__文件名
    '''
    import pickle                    # 导入pickle模块
    # 按照书中这里写的'w',将会报错write() argument must be str,not bytes
    # 所以这里将改写为'wb'
    fw = open(filename, 'wb')        # 创建一个可以“写入”的文件
    pickle.dump(inputTree, fw)       # pickle的dump函数将决策树写入文件中  
    fw.close()                       # 写完成后关闭文件
def gradTree(filename):
    '''
    函数功能:
            将树从磁盘中取出
    函数参数:
            filename__文件名
    '''
    import pickle                     # 导入pickle模块
    fr = open(filename, 'rb')         # 使用'rb'读出数据
    return pickle.load(fr)    

# 函数测试
# labels = ['no surfacing', 'flippers', 'labels']
# tree = createTree(dataSet, labels)
storeTree(tree, 'classifer.json')
#gradTree('classifer.json')

测试隐形演讲类型

  • 数据集

  • 代码:

    #查看决策树代码
    from math import log
    import matplotlib.pyplot as plt
    def calcShannonEnt(dataSet):
        # 求list的长度,表示计算参与训练的数据量
        numEntries = len(dataSet)
        # 计算分类标签label出现的次数
        labelCounts = {}
        # the the number of unique elements and their occurrence
        for featVec in dataSet:
            # 将当前实例的标签存储,即每一行数据的最后一个数据代表的是标签
            currentLabel = featVec[-1]
            # 为所有可能的分类创建字典,如果当前的键值不存在,则扩展字典并将当前键值加入字典。每个键值都记录了当前类别出现的次数。
            if currentLabel not in labelCounts.keys():
                labelCounts[currentLabel] = 0
            labelCounts[currentLabel] += 1
    
        # 对于 label 标签的占比,求出 label 标签的香农熵
        shannonEnt = 0.0
        for key in labelCounts:
            # 使用所有类标签的发生频率计算类别出现的概率。
            prob = float(labelCounts[key])/numEntries
            # 计算香农熵,以 2 为底求对数
            shannonEnt -= prob * log(prob, 2)
        return shannonEnt
    #将某一个特征直接删除
    def splitDataSet(dataSet, index, value):
        """splitDataSet(通过遍历dataSet数据集,求出index对应的colnum列的值为value的行)
            就是依据index列进行分类,如果index列的数据等于 value的时候,就要将 index 划分到我们创建的新的数据集中
        Args:
            dataSet 数据集                 待划分的数据集
            index 表示每一行的index列        划分数据集的特征
            value 表示index列对应的value值   需要返回的特征的值。
        Returns:
            index列为value的数据集【该数据集需要排除index列】
        """
        retDataSet = []
        for featVec in dataSet:
            # index列为value的数据集【该数据集需要排除index列】
            # 判断index列的值是否为value
            if featVec[index] == value:
                # chop out index used for splitting
                # [:index]表示前index行,即若 index 为2,就是取 featVec 的前 index 行
                reducedFeatVec = featVec[:index]
                '''
                请百度查询一下: extend和append的区别
                music_media.append(object) 向列表中添加一个对象object
                music_media.extend(sequence) 把一个序列seq的内容添加到列表中 (跟 += 在list运用类似, music_media += sequence)
                1、使用append的时候,是将object看作一个对象,整体打包添加到music_media对象中。
                2、使用extend的时候,是将sequence看作一个序列,将这个序列和music_media序列合并,并放在其后面。
                music_media = []
                music_media.extend([1,2,3])
                print music_media
                #结果:
                #[1, 2, 3]
    
                music_media.append([4,5,6])
                print music_media
                #结果:
                #[1, 2, 3, [4, 5, 6]]
    
                music_media.extend([7,8,9])
                print music_media
                #结果:
                #[1, 2, 3, [4, 5, 6], 7, 8, 9]
                '''
                reducedFeatVec.extend(featVec[index + 1:])
                # [index+1:]表示从跳过 index 的 index+1行,取接下来的数据
                # 收集结果值 index列为value的行【该行需要排除index列】
                retDataSet.append(reducedFeatVec)
        return retDataSet
    #选择最好的数据集划分方式
    def chooseBestFeatureToSplit(dataSet):
        """chooseBestFeatureToSplit(选择最好的特征)
    
        Args:
            dataSet 数据集
        Returns:
            bestFeature 最优的特征列
        """
        # 求第一行有多少列的 Feature, 最后一列是label列嘛
        numFeatures = len(dataSet[0]) - 1
        # 数据集的原始信息熵
        baseEntropy = calcShannonEnt(dataSet)
        # 最优的信息增益值, 和最优的Featurn编号
        bestInfoGain, bestFeature = 0.0, -1
        # iterate over all the features
        for i in range(numFeatures):
            # create a list of all the examples of this feature
            # 获取对应的feature下的所有数据
            featList = [example[i] for example in dataSet]
            # get a set of unique values
            # 获取剔重后的集合,使用set对list数据进行去重
            uniqueVals = set(featList)
            # 创建一个临时的信息熵
            newEntropy = 0.0
            # 遍历某一列的value集合,计算该列的信息熵
            # 遍历当前特征中的所有唯一属性值,对每个唯一属性值划分一次数据集,计算数据集的新熵值,并对所有唯一特征值得到的熵求和。
            for value in uniqueVals:
                subDataSet = splitDataSet(dataSet, i, value)
                # 计算概率
                prob = len(subDataSet)/float(len(dataSet))
                # 计算信息熵
                newEntropy += prob * calcShannonEnt(subDataSet)
            # gain[信息增益]: 划分数据集前后的信息变化, 获取信息熵最大的值
            # 信息增益是熵的减少或者是数据无序度的减少。最后,比较所有特征中的信息增益,返回最好特征划分的索引值。
            infoGain = baseEntropy - newEntropy
            print('infoGain=', infoGain, 'bestFeature=', i, baseEntropy, newEntropy)
            if (infoGain > bestInfoGain):
                bestInfoGain = infoGain
                bestFeature = i
        return bestFeature
    # 类似于KNN中的投票机制
    def majorityCnt(classList):
        classCount = {}
        for vote in classList:
            if vote not in classCount.keys():
                classCount[vote] = 0
            else:
                classCount[vote] += 1
            sortedClassCount = sorted(classCount.item(),key=lambda x:x[1],reverse = True)
            return sortedClassCount[0][0]
    def createTree(dataSet, labels):
        classList = [example[-1] for example in dataSet]
        # 如果数据集的最后一列的第一个值出现的次数=整个集合的数量,也就说只有一个类别,就只直接返回结果就行
        # 第一个停止条件:所有的类标签完全相同,则直接返回该类标签。
        # count() 函数是统计括号中的值在list中出现的次数
        if classList.count(classList[0]) == len(classList):
            return classList[0]
        # 如果数据集只有1列,那么最初出现label次数最多的一类,作为结果
        # 第二个停止条件:使用完了所有特征,仍然不能将数据集划分成仅包含唯一类别的分组。
        if len(dataSet[0]) == 1:
            return majorityCnt(classList)
    
        # 选择最优的列,得到最优列对应的label含义
        bestFeat = chooseBestFeatureToSplit(dataSet)
        # 获取label的名称
        #print(bestFeat)
        #print(labels)
        bestFeatLabel = labels[bestFeat]
        # 初始化myTree
        myTree = {bestFeatLabel: {}}
        # 注:labels列表是可变对象,在PYTHON函数中作为参数时传址引用,能够被全局修改
        # 所以这行代码导致函数外的同名变量被删除了元素,造成例句无法执行,提示'no surfacing' is not in list
        del(labels[bestFeat])
        # 取出最优列,然后它的branch做分类
        print(labels)
        featValues = [example[bestFeat] for example in dataSet]
        uniqueVals = set(featValues)
        for value in uniqueVals:
            # 求出剩余的标签label
            subLabels = labels[:]
            # 遍历当前选择特征包含的所有属性值,在每个数据集划分上递归调用函数createTree(),字典里面还有小字典
            #print(dataSet)
            myTree[bestFeatLabel][value] = createTree(splitDataSet(dataSet, bestFeat, value), subLabels)
            print(myTree)
            # print 'myTree', value, myTree
        return myTree
    def createDataSet():
        dataSet = [[1, 1, 'yes'],
                   [1, 1, 'yes'],
                   [1, 0, 'no'],
                   [0, 1, 'no'],
                   [0, 1, 'no']]
        labels = ['no surfacing', 'flippers']
        return dataSet, labels
    myDat,labels = createDataSet()
    labels_copy = labels[:]
    mytree = createTree(myDat,labels_copy)
    # print(mytree)
    # print(myDat)
    #tree = createTree(myDat, labels)
    print(labels)
    tree = createTree(myDat,labels)
    fr = open(r'C:\Users\admin\Desktop\machine-learning\ai\DecisionTree\3.DecisionTree\lenses.txt')
    lenses = [inst.strip().split('\t') for inst in fr.readlines()]
    lensesLable = ['age', 'prescript', 'astigmatic', 'tearRate']
    lensesTree = createTree(lenses, lensesLable)
    #############################################################
    #获得叶子节点数目
    def getNumLeafs(myTree):
        '''
        函数功能:
                递归计算叶子节点数目
        函数参数:
                字典形式的决策树
        函数返回:
                叶子节点数目
        '''
        numLeafs = 0                                           # 初始化叶子节点的数目
        print(list(myTree.keys()))
        firstStr = list(myTree.keys())[0]                      # 获取决策树的第一个节点
        secondDict = myTree[firstStr]                          # 获取决策树的第二个节点
        for key in secondDict.keys():
            if type(secondDict[key]).__name__ == 'dict':       # 若该节点为字典形式
                numLeafs += getNumLeafs(secondDict[key])       # 若为字典,则递归计算新分支叶节点数
            else:
                numLeafs += 1                                  # 若不是字典,则此节点为叶子节点
        return numLeafs                                        # 返回叶子节点数目
    
    # 函数测试
    # labels = ['no surfacing', 'flippers', 'labels']
    # labels_copy2 = labels[:]
    # print(myDat)
    # print(labels)
    # tree = createTree(myDat,labels_copy2)
    #得到树的深度
    def getTreeDepth(myTree):
        '''
        函数功能:
                递归计算决策树的深度
        函数参数:
                myTree__字典形式的决策树
        函数返回:
                决策树的最大深度
        '''
        maxDepth = 0
        firstStr = list(myTree.keys())[0]
        secondDict = myTree[firstStr]
        for key in secondDict.keys():
            if type(secondDict[key]).__name__ == 'dict':
                thisDepth = 1 + getTreeDepth(secondDict[key])
            else:
                thisDepth = 1
            if thisDepth > maxDepth:
                maxDepth = thisDepth
        return maxDepth
    #使用文本标注
    decisionNode = dict(boxstyle = 'sawtooth', fc = '0.8')    # 设置中间节点的格式
    leafNode = dict(boxstyle = 'round4', fc = '0.8')          # 设置叶子节点的格式
    arrow_args = dict(arrowstyle = '<-')                      # 定义箭头格式
    def plotNode(nodeTxt, centerPt, parentPt, nodeType):
        '''
        函数功能:
                绘制节点
        参数说明:
                nodeTxt__节点名
                centerPt__文本位置
                parentPt__标注的箭头位置
                nodeType__节点格式
        '''
        createPlot.ax1.annotate(nodeTxt,                                          # 文本内容
                               xy = parentPt, xycoords = 'axes fraction',         # 注释的起始位置,坐标系
                               xytext = centerPt, textcoords = 'axes fraction',   # 文本的起始位置
                               va = 'center', ha = 'center',                      # 水平对齐,垂直对齐
                               bbox = nodeType,                                   # 节点格式
                               arrowprops = arrow_args)                           # 箭头格式
    #标注有向边
    def plotMidText(cntrPt, parentPt, txtString):
        '''
        函数功能:
                标注有向边内容
        参数说明:
                cntrpt、parentPt__计算标注位置
                txtString__标注内容
        '''
        xMid = (parentPt[0] - cntrPt[0])/ 2.0 + cntrPt[0]        # 计算文本位置的横坐标
        yMid = (parentPt[1] - cntrPt[1])/2.0 + cntrPt[1]         # 计算文本位置的纵坐标
        createPlot.ax1.text(xMid,                                # 文本位置的横坐标
                            yMid,                                # 文本位置的纵坐标
                            txtString)                           # 标注内容
    def plotTree(myTree, parentPt, nodeTxt):
        '''
        函数功能:
                绘制决策树
        函数参数:
                myTree__决策树
                parentPt__标注的内容
                nodeTxt__节点名称
        '''
        numLeafs = getNumLeafs(myTree)                          # 获取决策树叶结点数目,决定了树的宽度
        depth = getTreeDepth(myTree)                            # 获取决策树层数,决定了树的高度
        firstStr = next(iter(myTree))                           # 获得决策树第一个节点
        cntrPt = (plotTree.xOff+(1.0+float(numLeafs))/2.0/plotTree.totalW,plotTree.yOff)#确定中心位置
        plotMidText(cntrPt, parentPt, nodeTxt)                  # 标注有向边内容
        plotNode(firstStr, cntrPt, parentPt, decisionNode)      # 绘制节点
        secondDict = myTree[firstStr]                           # 获取下一个字典
        plotTree.yOff = plotTree.yOff - 1.0/plotTree.totalD     # y偏移值
        for key in secondDict.keys():
            if type(secondDict[key]).__name__ == 'dict':        # 该结点是否为字典
                plotTree(secondDict[key], cntrPt, str(key))     # 如果是字典则不是叶结点,递归调用继续绘制
            else:
                plotTree.xOff = plotTree.xOff + 1.0/plotTree.totalW                          # x偏移值
                plotNode(secondDict[key], (plotTree.xOff, plotTree.yOff), cntrPt, leafNode)  # 绘制节点,y已经偏移
                plotMidText((plotTree.xOff, plotTree.yOff), cntrPt, str(key))                # 标注有向边内容
        plotTree.yOff = plotTree.yOff + 1.0/plotTree.totalD
    def createPlot(inTree):
        '''
        函数功能:绘制完整的决策树
        参数说明:
                inTree__决策树
        '''
        fig = plt.figure(1, facecolor='white')                      #创建画布
        fig.clf()                                                   #清空画布
        axprops = dict(xticks=[], yticks=[])
        createPlot.ax1 = plt.subplot(111, frameon=False, **axprops) # 除去x、y轴
        plotTree.totalW = float(getNumLeafs(inTree))                # 获取决策树叶结点数目
        plotTree.totalD = float(getTreeDepth(inTree))               # 获取决策树深度
        plotTree.xOff = -0.5/plotTree.totalW                        # x偏移的初始值
        plotTree.yOff = 1.0                                         # y偏移的初始值
        plotTree(inTree, (0.5,1.0), '')                             # 绘制决策树,父节点(0.5,1.0)
        plt.show()                                                  #显示图像
    
    # 测试函数
    # labels = ['no surfacing', 'flippers', 'labels']
    # tree = createTree(dataSet, labels)
    # createPlot(tree)
    createPlot(lensesTree)
    
    

Logo

CSDN联合极客时间,共同打造面向开发者的精品内容学习社区,助力成长!

更多推荐