决策树剪枝算法的实现相对而言比较平凡,只需要把算法依次翻译成程序语言即可
ID3、C4.5 剪枝算法的实现
回忆算法本身,可以知道我们需要获取“从下往上”这个顺序,为此我们需要先在CvDNode
中利用递归定义一个函数来更新 Tree 的self.layers
属性:
|
|
然后、在CvDBase
中定义一个对应的函数进行封装:
|
|
同时,为了做到合理的代码重用、我们可以先在CvDNode
中定义一个计算损失的函数:
|
|
有了以上两个函数,算法本身的实现就很直观了:
|
|
上述代码的第 25 行和第 28 行出现了 Node 的affected
属性,这是我们之前没有进行定义的(因为若在彼时定义会显得很突兀);不过由剪枝算法可知,这个属性的用处与其名字一致——标记一个 Node 是否是“被影响到的”Node。事实上,在一个 Node 进行了局部剪枝后,会有两类 Node “被影响到”:
- 该 Node 的子节点、子节点的子节点……等等,它们属于被剪掉的 Node、应该要将它们在
_old
、_tmp_nodes
中对应的位置从这些列表中除去 - 该 Node 的父节点、父节点的父节点……等等,它们存储叶节点的列表会因局部剪枝而发生改变、所以要更新
_old
和_mask
列表中对应位置的值
其中,我们之前定义的 Node 中是用pruned
属性来标记该 Node 是否已被剪掉、且介绍了如何通过递归来更新pruned
属性;affected
属性和pruned
属性的本质几乎没什么区别,所以我们同样可以通过递归来更新affected
属性。具体而言,我们只需:
- 在初始化时令
self.affected = False
- 在局部剪枝函数内部插入
_parent.affected = True
即可,其余部分可以保持不变。
CART 剪枝算法的实现
同样的,为了做到合理的代码重用、我们先利用之前实现的cost
函数、在CvDNode
里面定义一个获取 Node 阈值的函数:
|
|
由于算法本身的实现的思想以及用到的工具都和第一种剪枝算法大同小异、所以代码写起来也差不多:
|
|
代码第 4 行对根节点调用的cut_tree
方法同样是利用递归实现的:
|
|
然后就是最后一步、通过交叉验证选出最优树了。注意到之前我们封装生成算法时、最后一行调用了剪枝算法的封装——self.prune
方法。由于该方法是第一个接收了交叉验证集x_cv
和y_cv
的方法、所以我们应该让该方法来做交叉验证。简洁起见,我们直接选用加权正确率作为交叉验证的标准:
|
|