剪枝算法的实现

(本文会用到的所有代码都在这里这里

决策树剪枝算法的实现相对而言比较平凡,只需要把算法依次翻译成程序语言即可

ID3、C4.5 剪枝算法的实现

回忆算法本身,可以知道我们需要获取“从下往上”这个顺序,为此我们需要先在CvDNode中利用递归定义一个函数来更新 Tree 的self.layers属性:

1
2
3
4
5
6
7
8
def update_layers(self):
# 根据该Node的深度、在self.layers对应位置的列表中记录自己
self.tree.layers[self._depth].append(self)
# 遍历所有子节点、完成递归
for _node in sorted(self.children):
_node = self.children[_node]
if _node is not None:
_node.update_layers()

然后、在CvDBase中定义一个对应的函数进行封装:

1
2
3
4
def _update_layers(self):
# 根据整颗决策树的高度、在self.layers里面放相应数量的列表
self.layers = [[] for _ in range(self.root.height)]
self.root.update_layers()

同时,为了做到合理的代码重用、我们可以先在CvDNode中定义一个计算损失的函数:

1
2
3
4
def cost(self, pruned=False):
if not pruned:
return sum([leaf["chaos"] * len(leaf["y"]) for leaf in self.leafs.values()])
return self.chaos * len(self._y)

有了以上两个函数,算法本身的实现就很直观了:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
def _prune(self):
self._update_layers()
_tmp_nodes = []
# 更新完决策树每一“层”的Node之后,从后往前地向 _tmp_nodes中加Node
for _node_lst in self.layers[::-1]:
for _node in _node_lst[::-1]:
if _node.category is None:
_tmp_nodes.append(_node)
_old = np.array([node.cost() + self.prune_alpha * len(node.leafs)
for node in _tmp_nodes])
_new = np.array([node.cost(pruned=True) + self.prune_alpha
for node in _tmp_nodes])
# 使用 _mask变量存储 _old和 _new对应位置的大小关系
_mask = _old >= _new
while True:
# 若只剩根节点就退出循环体
if self.root.height == 1:
return
p = np.argmax(_mask)
# 如果 _new中有比 _old中对应损失小的损失、则进行局部剪枝
if _mask[p]:
_tmp_nodes[p].prune()
# 根据被影响了的Node、更新 _old、_mask对应位置的值
for i, node in enumerate(_tmp_nodes):
if node.affected:
_old[i] = node.cost() + self.prune_alpha * len(node.leafs)
_mask[i] = _old[i] >= _new[i]
node.affected = False
# 根据被剪掉的Node、将各个变量对应的位置除去(注意从后往前遍历)
for i in range(len(_tmp_nodes) - 1, -1, -1):
if _tmp_nodes[i].pruned:
_tmp_nodes.pop(i)
_old = np.delete(_old, i)
_new = np.delete(_new, i)
_mask = np.delete(_mask, i)
else:
break
self.reduce_nodes()

上述代码的第 25 行和第 28 行出现了 Node 的affected属性,这是我们之前没有进行定义的(因为若在彼时定义会显得很突兀);不过由剪枝算法可知,这个属性的用处与其名字一致——标记一个 Node 是否是“被影响到的”Node。事实上,在一个 Node 进行了局部剪枝后,会有两类 Node “被影响到”:

  • 该 Node 的子节点、子节点的子节点……等等,它们属于被剪掉的 Node、应该要将它们在_old_tmp_nodes中对应的位置从这些列表中除去
  • 该 Node 的父节点、父节点的父节点……等等,它们存储叶节点的列表会因局部剪枝而发生改变、所以要更新_old_mask列表中对应位置的值

其中,我们之前定义的 Node 中是用pruned属性来标记该 Node 是否已被剪掉、且介绍了如何通过递归来更新pruned属性;affected属性和pruned属性的本质几乎没什么区别,所以我们同样可以通过递归来更新affected属性。具体而言,我们只需:

  • 在初始化时令self.affected = False
  • 在局部剪枝函数内部插入_parent.affected = True

即可,其余部分可以保持不变。

CART 剪枝算法的实现

同样的,为了做到合理的代码重用、我们先利用之前实现的cost函数、在CvDNode里面定义一个获取 Node 阈值的函数:

1
2
def get_threshold(self):
return (self.cost(pruned=True) - self.cost()) / (len(self.leafs) - 1)

由于算法本身的实现的思想以及用到的工具都和第一种剪枝算法大同小异、所以代码写起来也差不多:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
def _cart_prune(self):
# 暂时将所有节点记录所属Tree的属性置为None
# 这样做的必要性会在后文进行说明
self.root.cut_tree()
_tmp_nodes = [node for node in self.nodes if node.category is None]
_thresholds = np.array([node.get_threshold() for node in _tmp_nodes])
while True:
# 利用deepcopy对当前根节点进行深拷贝、存入self.roots列表
# 如果前面没有把记录Tree的属性置为None,那么这里就也会对整个Tree做
# 深拷贝。可以想象、这样会引发严重的内存问题,速度也会被拖慢非常多
root_copy = deepcopy(self.root)
self.roots.append(root_copy)
if self.root.height == 1:
break
p = np.argmin(_thresholds)
_tmp_nodes[p].prune()
for i, node in enumerate(_tmp_nodes):
# 更新被影响到的Node的阈值
if node.affected:
_thresholds[i] = node.get_threshold()
node.affected = False
for i in range(len(_tmp_nodes) - 1, -1, -1):
# 去除掉各列表相应位置的元素
if _tmp_nodes[i].pruned:
_tmp_nodes.pop(i)
_thresholds = np.delete(_thresholds, i)
self.reduce_nodes()

代码第 4 行对根节点调用的cut_tree方法同样是利用递归实现的:

1
2
3
4
5
def cut_tree(self):
self.tree = None
for child in self.children.values():
if child is not None:
child.cut_tree()

然后就是最后一步、通过交叉验证选出最优树了。注意到之前我们封装生成算法时、最后一行调用了剪枝算法的封装——self.prune方法。由于该方法是第一个接收了交叉验证集x_cvy_cv的方法、所以我们应该让该方法来做交叉验证。简洁起见,我们直接选用加权正确率作为交叉验证的标准:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
# 定义计算加权正确率的函数
@staticmethod
def acc(y, y_pred, weights):
if weights is not None:
return np.sum((np.array(y) == np.array(y_pred)) * weights) / len(y)
return np.sum(np.array(y) == np.array(y_pred)) / len(y)
def prune(self, x_cv, y_cv, weights):
if self.root.is_cart:
# 如果该Node使用CART剪枝,那么只有在确实传入了交叉验证集的情况下
# 才能调用相关函数、否则没有意义
if x_cv is not None and y_cv is not None:
self._cart_prune()
_arg = np.argmax([CvDBase.acc(
y_cv, tree.predict(x_cv), weights) for tree in self.roots])
_tar_root = self.roots[_arg]
# 由于Node的feed_tree方法会递归地更新nodes属性、所以要先重置
self.nodes = []
_tar_root.feed_tree(self)
self.root = _tar_root
else:
self._prune()
观众老爷们能赏个脸么 ( σ'ω')σ