python决策树可视化报错

def getNumLeafs(tree):
    """Count the leaf nodes of a nested-dict decision tree.

    The tree has the form ``{feature: {value: subtree_or_label}}``: the first
    key is the root feature, and its value maps each feature value either to
    another tree of the same form or to a class label (a leaf).

    Args:
        tree: Non-empty dict describing the (sub)tree.

    Returns:
        The number of leaves below the root of ``tree``.
    """
    firstNode = list(tree.keys())[0]
    second = tree[firstNode]
    # isinstance (instead of comparing the type name to 'dict') so that dict
    # subclasses such as OrderedDict are walked as subtrees, not counted as
    # a single leaf.
    return sum(
        getNumLeafs(child) if isinstance(child, dict) else 1
        for child in second.values()
    )
# Compute the depth of the tree; used to scale the y axis when drawing it.
def getTreeDepth(tree):
    """Return the number of decision levels in a nested-dict decision tree.

    The tree has the form ``{feature: {value: subtree_or_label}}``. A tree
    whose branches are all leaves has depth 1; every nested subtree adds one.

    Args:
        tree: Non-empty dict describing the (sub)tree.

    Returns:
        The depth of the deepest branch (0 if the root has no branches).
    """
    firstNode = list(tree.keys())[0]
    second = tree[firstNode]
    depthOfTree = 0
    for child in second.values():
        # isinstance (instead of comparing the type name to 'dict') so that
        # dict subclasses such as OrderedDict are treated as subtrees.
        thisNodeDepth = getTreeDepth(child) + 1 if isinstance(child, dict) else 1
        depthOfTree = max(depthOfTree, thisNodeDepth)
    return depthOfTree


# 用matplotlib绘制决策树
import matplotlib.pyplot as plt

# Text-box style for decision (internal) nodes. 'fc' is the box face colour,
# given here as a grayscale string between '0' (black) and '1' (white).
decisionNode = {'boxstyle': 'sawtooth', 'fc': '0.8'}
# Text-box style for leaf nodes.
leafNode = {'boxstyle': 'round4', 'fc': '1'}
# Arrow style passed as arrowprops when a node is drawn.
arrow_args = {'arrowstyle': '<-'}


# Draw a single node.
# annotate(text, xy, xycoords, xytext, textcoords, va, ha, bbox, arrowprops):
#   xy         - coordinates of the point being annotated
#   xytext     - position of the annotation text
#   xycoords / textcoords - coordinate systems of xy / xytext ('data' by default)
#   va, ha     - vertical / horizontal alignment of the text in its box
def plotNode(nodeTxt, nodeIndex, parentNodeIndex, nodeType):
    """Draw one tree node as a boxed annotation with an arrow.

    Args:
        nodeTxt: Text displayed inside the node's box.
        nodeIndex: (x, y) centre of the text box, in axes-fraction coordinates.
        parentNodeIndex: (x, y) point the annotation is attached to (the
            parent node's position), in axes-fraction coordinates.
        nodeType: bbox property dict selecting the box style of the node.
    """
    placement = dict(
        xy=parentNodeIndex,
        xycoords='axes fraction',
        xytext=nodeIndex,
        textcoords='axes fraction',
    )
    appearance = dict(
        va='center',
        ha='center',
        bbox=nodeType,
        arrowprops=arrow_args,
    )
    plt.annotate(nodeTxt, **placement, **appearance)


# Put a label on the edge between a node and its parent.
def plotMidText(thisNodeIndex, parentNodeIndex, text):
    """Write ``text`` at the midpoint between a node and its parent.

    Args:
        thisNodeIndex: (x, y) position of the child node.
        parentNodeIndex: (x, y) position of the parent node.
        text: Label to place halfway along the connecting edge.
    """
    childX, childY = thisNodeIndex[0], thisNodeIndex[1]
    halfDx = (parentNodeIndex[0] - childX) / 2.0
    halfDy = (parentNodeIndex[1] - childY) / 2.0
    # Midpoint = child position plus half of the child-to-parent offset.
    plt.text(halfDx + childX, halfDy + childY, text)


def plotTree(tree, parentNodeIndex, midTxt):
    """Recursively draw ``tree`` underneath the node at ``parentNodeIndex``.

    The root of ``tree`` is drawn as a decision node centred above the leaves
    it spans; each branch is then drawn either as a leaf or, recursively, as
    a subtree one level lower.

    Reads the module-level ``treeWidth`` and ``treeDepth`` and both reads and
    updates the module-level cursor ``xOff`` / ``yOff``.

    Args:
        tree: Nested dict of the form ``{feature: {value: subtree_or_label}}``.
        parentNodeIndex: (x, y) position of the parent node.
        midTxt: Label written on the edge leading from the parent to this node.
    """
    global xOff
    global yOff
    numOfLeafs = getNumLeafs(tree)
    nodeTxt, = tree.keys()
    # Centre this node horizontally over the leaves that belong to it.
    nodeIndex = (xOff + (1.0 + float(numOfLeafs)) / 2.0 / treeWidth, yOff)
    plotNode(nodeTxt, nodeIndex, parentNodeIndex, decisionNode)
    plotMidText(nodeIndex, parentNodeIndex, midTxt)
    secondDict = tree[nodeTxt]
    # Step one level down before drawing the children.
    yOff = yOff - 1.0 / treeDepth
    for key, child in secondDict.items():
        # isinstance (instead of comparing the type name to 'dict') so that
        # dict subclasses such as OrderedDict are drawn as subtrees.
        if isinstance(child, dict):
            plotTree(child, nodeIndex, str(key))
        else:
            # Each leaf occupies the next horizontal slot.
            xOff = xOff + 1.0 / treeWidth
            plotNode(child, (xOff, yOff), nodeIndex, leafNode)
            plotMidText((xOff, yOff), nodeIndex, str(key))
    # Restore the level so that siblings of this node are drawn at its height.
    yOff = yOff + 1.0 / treeDepth


def createPlot(tree):
    """Entry point for drawing: render ``tree`` in a figure and show it.

    Uses a white figure named 'DecisionTree', cleared before drawing, with a
    single frameless subplot that is also stored on ``createPlot.ax1``.
    Reads the module-level ``treeWidth`` and resets the module-level drawing
    cursor ``xOff`` / ``yOff`` that plotTree advances while placing nodes.

    Args:
        tree: Nested dict of the form ``{feature: {value: subtree_or_label}}``.
    """
    global xOff, yOff
    canvas = plt.figure('DecisionTree', facecolor='white')
    canvas.clf()
    # 111 = 1 row, 1 column, first cell; frameon=False hides the axes frame.
    createPlot.ax1 = plt.subplot(111, frameon=False)
    # Start half a leaf slot left of the axes and at the top of the plot.
    xOff, yOff = -0.5 / treeWidth, 1.0
    plotTree(tree, (0.5, 1.0), '')
    # Remove the tick marks on both axes.
    for setTicks in (plt.xticks, plt.yticks):
        setTicks([])
    plt.show()

def classify(inputTree, featureLabels, testVector):
    """Classify ``testVector`` by walking a nested-dict decision tree.

    Args:
        inputTree: Tree of the form ``{feature: {value: subtree_or_label}}``.
        featureLabels: Feature names; the position of a name is the index of
            that feature's value in ``testVector``.
        testVector: Feature values of the sample to classify.

    Returns:
        The class label stored at the leaf reached by ``testVector``.

    Raises:
        ValueError: If a feature of the tree is missing from
            ``featureLabels``, or if the sample's value for a feature has no
            matching branch (previously this surfaced as an opaque
            UnboundLocalError).
    """
    firstNode, = inputTree.keys()
    secondDict = inputTree[firstNode]
    featureIndex = featureLabels.index(firstNode)
    featureValue = testVector[featureIndex]
    for key, branch in secondDict.items():
        if featureValue == key:
            if isinstance(branch, dict):
                return classify(branch, featureLabels, testVector)
            return branch
    raise ValueError(
        'no branch for value {!r} of feature {!r}'.format(featureValue, firstNode)
    )


def storeTree(inputTree, filename):
    """Serialize ``inputTree`` to ``filename`` using pickle.

    Args:
        inputTree: The object (decision tree) to store.
        filename: Path of the file to write; it is opened in binary mode and
            overwritten if it exists.
    """
    import pickle
    # The context manager closes the file even if pickling fails.
    with open(filename, 'wb') as fh:
        pickle.dump(inputTree, fh)
def loadTree(filename):
    """Load and return the object previously pickled into ``filename``.

    Args:
        filename: Path of a pickle file, e.g. one written by storeTree.

    Returns:
        The unpickled object.
    """
    import pickle
    # NOTE: unpickling can execute arbitrary code; only load trusted files.
    # The context manager closes the file even if unpickling fails.
    with open(filename, 'rb') as fh:
        return pickle.load(fh)


# Read the training data and feature names; `with` closes each file promptly.
# NOTE(review): line.split() yields a list of *strings* per row, and labels
# ends up as a list of lists -- confirm this matches what createTree expects.
with open("feature.dat") as featureFile:
    dataSet = [line.split() for line in featureFile]
with open("name.dat") as nameFile:
    labels = [line.split() for line in nameFile]
# NOTE(review): createTree is not defined in the visible part of this file.
decisionTree = createTree(dataSet, labels)
storeTree(decisionTree, 'decisionTree')

myTree = loadTree('decisionTree')
featureLabels = ['no surfacing', 'flippers']

# plotTree and createPlot read these two module-level globals.
treeWidth = float(getNumLeafs(myTree))
treeDepth = float(getTreeDepth(myTree))
createPlot(myTree)
print(classify(myTree, featureLabels, [1, 0]))

用matplotlib画构造的决策树报错,小白不知道什么原因,希望能帮忙改一下代码,谢谢

原因是:原代码读取 dataSet 时,每行读到的数据仍然是字符串(没有转换成数值),所以无法进行后续处理;另外 labels 多了一层嵌套,应改为一维列表。修改如下:

# Parse the first 10 lines of feature.dat as comma-separated integer rows.
# `with` closes each file once it has been read.
with open("feature.dat") as featureFile:
    dataSet = [[int(x) for x in line.strip().split(',')]
               for line in featureFile.readlines()[:10]]
# One feature name per line, as a flat list of strings.
with open("name.dat") as nameFile:
    labels = [line.strip() for line in nameFile]

 

float(getNumLeafs(myTree)) 中的 myTree 参数应该是字典才对,而现在 myTree 是字符串。

 

您好,我是有问必答小助手,您的问题已经有小伙伴解答了,您看下是否解决,可以追评进行沟通哦~

如果有您比较满意的答案 / 帮您提供解决思路的答案,可以点击【采纳】按钮,给回答的小伙伴一些鼓励哦~~

ps:问答VIP仅需29元,即可享受5次/月 有问必答服务,了解详情>>>https://vip.csdn.net/askvip?utm_source=1146287632

非常感谢您使用有问必答服务,为了后续更快速的帮您解决问题,现诚邀您参与有问必答体验反馈。您的建议将会运用到我们的产品优化中,希望能得到您的支持与协助!

速戳参与调研>>>https://t.csdnimg.cn/Kf0y