参考代码如下:
- 1 def plotMidText(cntrPt, parentPt, txtString) : 2 xMid = (parentPt[0] - cntrPt[0]) / 2.0 + cntrPt[0] 3 yMid = (parentPt[1] - cntrPt[1]) / 2.0 + cntrPt[1] 4 createPlot.ax1.text(xMid, yMid, txtString, va = "center", ha = "center", rotation = 30) 5 6 def plotTree(myTree, parentPt, nodeTxt) : #
- if the first key tells you what feat was split on 7 numLeafs = getNumLeafs(myTree)#this determines the x width of this tree 8 depth = getTreeDepth(myTree) 9 firstStr = myTree.keys()[0]#the text label
- for this node should be this 10 cntrPt = (plotTree.xOff + (1.0 + float(numLeafs)) / 2.0 / plotTree.totalW, plotTree.yOff) 11 plotMidText(cntrPt, parentPt, nodeTxt) 12 plotNode(firstStr, cntrPt, parentPt, decisionNode) 13 secondDict = myTree[firstStr] 14 plotTree.yOff = plotTree.yOff - 1.0 / plotTree.totalD 15
- for key in secondDict.keys() : 16
- if type(secondDict[key]).__name__ == 'dict': #test to see
- if the nodes are dictonaires,
- if not they are leaf nodes 17 plotTree(secondDict[key], cntrPt, str(key))#recursion 18
- else: #it 's a leaf node print the leaf node
- 19 plotTree.xOff = plotTree.xOff + 1.0/plotTree.totalW
- 20 plotNode(secondDict[key], (plotTree.xOff, plotTree.yOff), cntrPt, leafNode)
- 21 plotMidText((plotTree.xOff, plotTree.yOff), cntrPt, str(key))
- 22 plotTree.yOff = plotTree.yOff + 1.0/plotTree.totalD
- 23 #if you do get a dictonary you know it's a tree,
- and the first element will be another dict 24 25 def createPlot(inTree) : 26 fig = plt.figure(1, facecolor = 'white') 27 fig.clf() 28 axprops = dict(xticks = [], yticks = []) 29 createPlot.ax1 = plt.subplot(111, frameon = False, **axprops) 30 plotTree.totalW = float(getNumLeafs(inTree)) 31 plotTree.totalD = float(getTreeDepth(inTree)) 32 plotTree.xOff = -0.5 / plotTree.totalW;
- plotTree.yOff = 1.0;
- 33 plotTree(inTree, (0.5, 1.0), '') 34 plt.show()
第一个函数是在父子节点中填充文本信息,函数中是将父子节点的横纵坐标相加除以 2,上面写得有一点点不一样,但原理是一样的,然后还是在这个中间坐标的基础上添加文本,还是用的是 createPlot.ax1 这个全局变量,使用它的成员函数 text 来添加文本,里面是它的一些参数。
第二个函数是关键,它调用前面我们说过的函数,用树的宽度用于计算放置判断节点的位置 , 主要的计算原则是将它放在所有叶子节点的中间, 而不仅仅是它子节点的中间,根据高度就可以平分坐标系了,用坐标系的最大值除以高度,就是每层的高度。这个 plotTree 函数也是个递归函数,每次都是调用,画出一层,知道所有的分支都不是字典后,才算画完。每次检测出是叶子,就记录下它的坐标,并写出叶子的信息和父子节点间的信息。plotTree.xOff 和 plotTree.yOff 是用来追踪已经绘制的节点位置,以及放置下一个节点的恰当位置。
第三个函数我们之前介绍介绍过一个类似,这个函数调用了 plotTree 函数,最后输出树状图,这里只说两点,一点是全局变量 plotTree.totalW 存储树的宽度 , 全 局变量 plotTree.totalD 存储树的深度,还有一点是 plotTree.xOff 和 plotTree.yOff 是在这个函数这里初始化的。
最后我们来测试一下
- cd桌面 / machinelearninginaction / Ch03
来源: http://www.cnblogs.com/fydeblog/p/7159775.html