03 April 2012

平衡二叉树是一种绝对平衡的BST,它规定所有节点的左右子树高度相差不能大于1。这就要求每次新增、删除节点时,都需要重新计算左右子树树高,如果树高超过1就需要做节点的旋转,让它恢复到平衡状态。旋转方式全部四种。

左左旋转

如果由于往node.leftNode.leftNode插入新节点而导致不平衡,这时需要对node节点做左旋转。

"""\
    左左旋转
"""
def rotateLL(node):
    top = node.leftNode # 当前节点的左节点升级为top
    node.leftNode = top.rightNode # (如果有)top的右节点变为当前节点的左节点
    top.rightNode = node # 当前节点变为top的右节点
    return top

右右旋转

如果由于往node.rightNode.rightNode插入新节点而导致不平衡,这时需要对node节点做右旋转。

"""\
    右右旋转
"""
def rotateRR(node):
    top = node.rightNode
    node.rightNode = top.leftNode
    top.rightNode = node
    return top

左右旋转

如果是往node.leftNode.rightNode插入新节点导致不平衡,需要先对node.leftNode做右旋转,再对node自己做左旋转。

"""\
    左右旋转
"""
def rotateLR(node):
    node.leftNode = rotateRR(node.leftNode)
    return rotateLL(node)

右左旋转

如果往node.rightNode.leftNode插入新节点导致不平衡,需要先对node.rightNode做左旋转,再对node自己做右旋转。

"""\
    右左旋转
"""
def rotateRL(node):
    node.rightNode = rotateLL(node.rightNode)
    return rotateRR(node)

插入节点

有了上述四种基础旋转操作,就可以在插入、删除节点时应用上了。插入节点的步骤与BST一直,通过递归地与当前节点比对大小找到最终位置然后插入。不同点在于每次插入之后需要重新计算树高,如果不平衡就通过四种旋转调整到平衡状态。 为了避免每次递归都重新计算左右子树的树高,这里还用到一个小伎俩,就是将当前树高也作为参数递归传递下去,由每层节点累计(当前值+1)。

"""\
    插入新节点
    node: the AVL
    element : the new element
    currentDepth: current node's depth
"""
def insert(node, element, currentDepth):
    # 在最末端新增节点
    if node == None:
        currentDepth = 0
        return AVLNode(element, None, None)
    else:
        leftDepth = 0
        rightDepth = 0

        # 1. 递归添加节点
        # 1.1 如果小于当前节点则遍历左子树
        if node.element > element:
            # 新增节点做为当前的左节点
            node.leftNode = insert(node.leftNode, element, leftDepth)
            return node
        # 1.2 如果大于当前节点则遍历右子树
        elif node.element < element:
            # 新增节点作为当前的右节点
            node.rightNode = insert(node.rightNode, element, rightDepth)
            return node
        # 1.3 如果已经存在了只累计频率
        else:
            node.freqUp()
            return node

        # 2 累计树高
        if leftDepth > rightDepth:
            currentDepth = leftDepth + 1
        elif rightDepth < leftDepth:
            currentDepth = rightDepth + 1
            

        # 3. 处理旋转
        # 3.1 左子树比右子树高出2
        if leftDepth - rightDepth == 2:
            # 3.1.1 在左节点左边
            if element < node.leftNode.element:
                roteteLL(node)
            # 3.1.2 在左节点右边
            elif element > node.leftNode.element:
                rotateLR(node)
        # 3.2 右子树比左子树高出2
        elif rightDepth - leftDepth == 2:
            # 3.2.3 在右节点右边
            if element > node.rightNode.element:
                rotateRR(node)
            # 3.2.4 在右节点左边
            elif element < node.rightNode.element:
                rotateRL(node)

先简单做个测试,随机往AVL插入99个节点,然后中序遍历,看输出的结果是否按从小到大递增。

avl = None

x = 0
while x < 99:
    avl = AVL.insert(avl, random.randint(0, 999), 0)
    x +=1

BinaryTree.inOrder(avl)

""" 屏幕输出: """
6 14 31 38 41 46 57 60 74 75 130 133 135 140 189 196 214 218 244 249 265 271 273 290 298 302 304 337 338 361 371 398 416 428 438 449 457 465 467 471 485 486 502 506 507 513 515 558 561 565 567 574 578 579 585 608 621 647 658 691 693 716 719 724 728 737 748 764 765 770 784 816 838 842 869 881 882 884 891 906 913 924 931 934 939 940 966 971 974 983 986 995 

删除节点

"""\
    删除节点
    TODO: 这段代码还没经过测试,最后一段的判断语句写的有点挫
    node: the AVL
    element: delete element
    currentDepth: current tree depth
"""    
def delete(node, element, currentDepth):
    # 在最末端新增节点
    if node == None:
        currentDepth = 0
        return AVLNode(element, None, None)
    else:
        leftDepth = 0
        rightDepth = 0

        # 1. 递归删除节点
        # 1.1 如果小于当前节点则遍历左子树
        if node.element > element:
            # 新增节点做为当前的左节点
            node.leftNode = delete(node.leftNode, element, leftDepth)
            return node
        # 1.2 如果大于当前节点则遍历右子树
        elif node.element < element:
            # 新增节点作为当前的右节点
            node.rightNode = delete(node.rightNode, element, rightDepth)
            return node
        # 1.3 如果已经存在了先递减频率,另外需要考虑下面还有子节点的情况
        else:
            if node.freq > 0:
                node.freqDown()
                return node
            else:
                # 1.3.1 没有子节点直接删除
                if node.leftNode == None and node.rightNode == None:
                    return delete(node, element, currentDepth)
                # 1.3.2 存在左右子节点需要拿右节点最小值替代当前节点
                elif node.leftNode != None and node.rightNode = None:
                    # TODO
                    pass
                # 1.3.3 存在单个节点直接替换当前节点
                else:
                    # TODO
                    pass

        # 2 累计树高
        if leftDepth < rightDepth:
            currentDepth = leftDepth - 1
        elif rightDepth < leftDepth:
            currentDepth = rightDepth - 1
            

        # 3. 处理旋转(与插入节点的步骤相反,实现不是很干脆,讲究..)
        # 3.1 左子树比右子树高出2
        if abs(leftDepth) - abs(rightDepth) == 2:
            # 3.2.3 在右节点左边
            if node.rightNode != None and ((node.rightNode.leftNode == None and node.rightNode.rightNode != None) or (node.rightNode.leftNode != None and not node.rightNode.rightNode.isLeaf())):
                rotateRL(node)
            # 3.2.4 在右节点左边
            elif node.rightNode != None and ((node.rightNode.rightNode == None and node.rightNode.leftNode != None) or (node.rightNode.rightNode != None and not node.rightNode.leftNode.isLeaf())):
                rotateRR(node)
    
        # 3.2 右子树比左子树高出2
        elif abs(rightDepth) - abs(leftDepth) == 2:
            # 3.1.1 在左节点右边
            if node.leftNode != None and ((node.leftNode.leftNode == None and node.leftNode.rightNode != None) or (node.leftNode.leftNode != None and not node.leftNode.rightNode.isLeaf())):
                roteteLL(node)
            # 3.1.2 在左节点左边
            elif node.leftNode != None and ((node.leftNode.rightNode == None and node.leftNode.leftNode != None) or (node.leftNode.rightNode != None and not node.leftNode.leftNode.isLeaf())):
                rotateLR(node)

性能对比

这里先在本地内存上做下BST和AVL的插入、查询性能对比。分别对BST和AVL顺序插入新节点0-9999,然后按相同的条件进行查询。

bst = None
avl = None

print(datetime.datetime.now())
print('init BST')

i = 0
while i < 19999:
    # bst = BST.insert(bst, random.randint(0, 999999))
    bst = BST.insert(bst, i)
    i += 1

print(datetime.datetime.now())
print('init AVL')

j = 0
while j < 19999:
    # avl = AVL.insert(avl, random.randint(0, 999999), 0)
    #avl = AVL.insert(avl, random.randint(0, 999999), 0)
    avl = AVL.insert(avl, j, 0)
    j += 1

print(datetime.datetime.now())
print('bst search begin')

bstResult = [] * 5000
BST.searchII(bst, 1, 100, bstResult)
BST.searchII(bst, 999, 1800, bstResult)
BST.searchII(bst, 65687, 70000, bstResult)
BST.searchII(bst, 9998, 9999, bstResult)
BST.searchII(bst, 4000, 5000, bstResult)

print(datetime.datetime.now())
print('avl search begin')

avlResult = [] * 5000
BST.searchII(avl, 1, 100, avlResult)
BST.searchII(avl, 999, 1800, avlResult)
BST.searchII(avl, 6558, 7000, avlResult)
BST.searchII(avl, 9998, 9999, avlResult)
BST.searchII(avl, 4000, 5000, avlResult)

print(datetime.datetime.now())
print('end')

""" 屏幕输出 """
begin
2013-03-31 16:08:47.545757
init BST
2013-03-31 16:09:36.126257
init AVL
2013-03-31 16:10:24.857230
bst search begin
2013-03-31 16:10:24.889376
avl search begin
2013-03-31 16:10:24.916285
end

BST插入耗时48.419500秒,AVL插入耗时47.730973秒,BST查询耗时0.32146秒,AVL查询耗时0.26909秒。好吧,数据量太少,差距还不是很明显,下次换到磁盘上试试。