To boldly go where no one has gone before.

【学习笔记】Treap 平衡树

2019-03-09 01:43:20


没看到写得很透彻的$Treap$题解,作为初学者还是自己写一篇吧。

$Treap$

定义:如果一棵二叉树$T$满足$\forall x\in T$,$x>$ $x$左子树中的所有节点,且$x<$ $x$右子树中的所有节点,则$T$可被称为二叉查找树$BST$ $(Binary$ $Search$ $Tree)$。

顾名思义,$Treap=Tree+Heap$,所以$Treap$平衡树其实质上就是一个把堆套在树上的结构。如果我们需要对一个序列进行插入、删除、求前驱后继、求第$k$大数、求排名等操作,如果只是简单地建立一棵$BST$来跑的话,复杂度和树的高度有关 。其最坏复杂度是$O(n)$,例如输入数据有序的时候,树直接就变成一条链了。所以我们希望可以在维护$BST$性质的情况下,尽可能地把树的高度压缩,这样就可以降低时间复杂度了。

而对于一棵$BST$,有两个操作:右旋$(Zig)$和左旋$(Zag)$,它们不会改变$BST$性质,却可以改变节点间的父子关系。

$Zig$操作

zig.jpg 具体步骤为:

  1. 把蓝色节点接到红色节点的左儿子上面。由于蓝色节点一定小于红色节点而大于绿色节点,因此可以把它接到红色节点的左儿子上面。
  2. 把绿色点的右儿子替换为红色节点。
  3. 把红色节点维护的节点数信息更新到绿色节点上。
  4. 将$Zig$之后的红色节点维护的节点数信息更新,然后把红色节点换掉。
    • 注意,更新节点数信息的时候一定不要更新成$Zig$后变成树根的那个点,而更新的是$Zig$后变成树根右儿子的那个点。这里被坑了。
inline void Zig(int &x) {
    int k = l(x);
    l(x) = r(k), r(k) = x;
    s(k) = s(x);
    update(x); // attention!!
    x = k;
    return;
}

$Zag$操作

zag.jpg$Zag$与$Zig$操作无异,只是$l$与$r$互换。$Zig$和$Zag$互为逆操作。

inline void Zag(int &x) {
    int k = r(x);
    r(x) = l(k), l(k) = x;
    s(k) = s(x);
    update(x); // attention!!
    x = k;
    return;
}

维护方式

我们与$BST$建立一个完全对应的小根堆,每次新建节点的时候就随机为节点分配一个优先级,并且用$Zig$和$Zag$维护小根堆。由于堆的性质,其高度为$log(n)$,因此用这种方法就可以在维持$BST$性质的前提下把$BST$的高度压缩到$log(n)$,因此接下来所有的操作最坏复杂度就全部变成$log(n)$级别了。

插入节点

插入结点的时候有两种情况:

  1. 当前插入位置是空点,则直接在这里新建一个节点,并且初始化节点属性。
  2. 当前插入位置不是空点,先更新当前节点维护的节点数。若插入值和当前节点值相等,直接把$c(x)+1$;若插入值小于当前节点值,则在左子树中插入,并且维护堆属性:若左儿子的优先级小于当前节点的优先级,就$Zig$一下;同样地,若插入值大于当前节点值,则在右子树中插入,并且维护堆属性:若右儿子的优先级小于当前节点的优先级,就$Zag$一下。

这样就完成了对节点的插入。

inline void Insert(int &x, int v) {
    if(!x) {
        x = ++id;
        v(x) = v, c(x) = s(x) = 1;
        l(x) = r(x) = 0;
        p(x) = rand();
        return;
    }
    s(x)++;
    if(v(x) == v)
        c(x)++;
    else {
        if(v < v(x)) {
            Insert(l(x), v);
            if(p(l(x)) < p(x))
                Zig(x);
        } else {
            Insert(r(x), v);
            if(p(r(x)) < p(x))
                Zag(x);
        }
    }
    return;
}

删除节点

删除节点的时候有两种大情况:

  1. 当前节点值等于删除节点值,若当前节点计数$\ge1$则直接$c(x)-1,s(x)-1$即可;若当前节点计数为$1$,则:①如果该点为链节点$($只有一个儿子$)$,直接把它替换为儿子即可;②如果该点有左右儿子,那么就类似于堆的删除,若左儿子优先级小于右儿子则$Zig$,再继续删除;否则$Zag$后再继续删除。
  2. 当前节点不等于删除节点值,那么先更新$s(x)$,若删除节点值小于当前节点值则从左节点继续开始删除;若删除节点值大于当前节点值则从右节点继续开始删除。
    inline void Delete(int &x, int v) {
    if(v(x) == v) {
        if(c(x) > 1)
            c(x)--, s(x)--;
        else
            if(!l(x) || !r(x))
                x = l(x) + r(x);
                // x = l(x) ? l(x) : r(x);
            else
                if(p(l(x)) < p(r(x)))
                    Zig(x), Delete(x, v);
                else
                    Zag(x), Delete(x, v);
        return;
    }
    s(x)--; // attention!! 
    if(v < v(x))
        Delete(l(x), v);
    else
        Delete(r(x), v);
    return;
    }

    查找第$k$大数

    直接从根开始查找。对于查询的序数$k$,可知:如果$k\in(s(l(x)), s(l(x))+c(x)]$,则比$k$小的就只有$s(l(x))$个,因此$v(x)$就是答案。从树根开始查找,只要$x\neq0$,则:

  3. 若$k\in(s(l(x)), s(l(x))+c(x)]$,则直接返回$v(x)$。
  4. 若$k\notin(s(l(x)), s(l(x))+c(x)]$,①$k\le s(l(x))$,则查询的点在左子树内,因此把$x$更新为$l(x)$,再继续查找;②$k>s(l(x))$,则查询的点在左子树内,因此把$k$更新为$k-s(l(x))-c(x)$,把$x$更新为$r(x)$,即在右子树里查找第$k-s(l(x))-c(x)$大节点$($等价于更新之前的查找$)$,然后继续查找。
    inline int QueryKth(int k) {
    int x = root, _k = k;
    while(x) {
        if(_k > s(l(x)) && _k <= s(l(x)) + c(x))
            return v(x);
        if(_k <= s(l(x)))
            x = l(x);
        else
            _k -= s(l(x)) + c(x), x = r(x);
    }
    return 0;
    }

    查找前驱与后继

    这里定义:$x$的前驱指序列中小于$x$的最大数;$x$的后继指序列中大于$x$的最小数。$($有些地方的定义和这个不一样$)$

求前驱与求后继的思路相似,这里以前驱为例: 用$res$维护答案。从根节点开始搜索,只要$x\neq0$,则:

  1. 如果当前节点值小于查找值,则把$res$更新为当前节点值,然后把$x$更新为$r(x)$以寻找右子树中是否还有更优的解。
  2. 如果当前节点值大于查找值,则把$x$更新为$l(x)$以寻找左子树中是否有小于查找值的节点。后继的算法与前驱相似,但略有区别,不再赘述。
inline int QueryPre(int v) {
    int x = root, res = -INF;
    while(x) {
        if(v(x) < v)
            res = v(x), x = r(x);
        else
            x = l(x);
    }
    return res;
}

inline int QuerySuf(int v) {
    int x = root, res = INF;
    while(x) {
        if(v(x) > v)
            res = v(x), x = l(x);
        else
            x = r(x);
    }
    return res;
}

获取排名

最后就是利用$Treap$求查找值在序列中的排名。维护一个当前累计排名$rank$,从根节点开始,只要$x\neq0$,则:

  1. 如果当前节点值等于查找值,直接返回$rank+s(l(x))+1$。
  2. 如果当前节点值小于查找值,则把$x$更新为$l(x)$;如果当前节点值大于查找值,先把$rank$更新为$rank+s(l(x))+c(x)$,表示当前已有$rank+s(l(x))+c(x)$个节点小于查找值。 最后返回$rank$即可。
inline int QueryRank(int v) {
    int x = root, rank = 0;
    while(x) {
        if(v(x) == v)
            return rank + s(l(x)) + 1;
        if(v < v(x))
            x = l(x);
        else
            rank += s(l(x)) + c(x), x = r(x);
    }
    return rank;
}

完整代码:

//code
#include<bits/stdc++.h>
#include<ctime>
#define l(x) tree[x].lson
#define r(x) tree[x].rson
#define v(x) tree[x].val
#define p(x) tree[x].pri
#define c(x) tree[x].cnt
#define s(x) tree[x].size
using namespace std;

const int MAXN = 100010;
const int INF = 1 << 30;

struct TREE {
    int lson, rson, val, pri, cnt, size;
} tree[MAXN];
int id = 0, root = 0;
int n, o, opt;

inline int read() {
    int res = 0, uz = 1;
    char ch = getchar();
    while(ch < '0' || ch > '9') {
        if(ch == '-')
            uz = -1;
        ch = getchar();
    }
    while(ch >= '0' && ch <= '9') {
        res = (res << 3) + (res << 1) + (ch ^ '0');
        ch = getchar();
    }
    return res * uz;
}

inline void update(int &x) {
    s(x) = s(l(x)) + s(r(x)) + c(x);
    return;
}

inline void Zig(int &x) {
    int k = l(x);
    l(x) = r(k), r(k) = x;
    s(k) = s(x);
    update(x); // attention!!
    x = k;
    return;
}

inline void Zag(int &x) {
    int k = r(x);
    r(x) = l(k), l(k) = x;
    s(k) = s(x);
    update(x); // attention!!
    x = k;
    return;
}

inline void Insert(int &x, int v) {
    if(!x) {
        x = ++id;
        v(x) = v, c(x) = s(x) = 1;
        l(x) = r(x) = 0;
        p(x) = rand();
        return;
    }
    s(x)++;
    if(v(x) == v)
        c(x)++;
    else {
        if(v < v(x)) {
            Insert(l(x), v);
            if(p(l(x)) < p(x))
                Zig(x);
        } else {
            Insert(r(x), v);
            if(p(r(x)) < p(x))
                Zag(x);
        }
    }
    return;
}

inline void Delete(int &x, int v) {
    if(v(x) == v) {
        if(c(x) > 1)
            c(x)--, s(x)--;
        else
            if(!l(x) || !r(x))
                x = l(x) + r(x);
            else
                if(p(l(x)) < p(r(x)))
                    Zig(x), Delete(x, v);
                else
                    Zag(x), Delete(x, v);
        return;
    }
    s(x)--; // attention!! 
    if(v < v(x))
        Delete(l(x), v);
    else
        Delete(r(x), v);
    return;
}

inline int QueryKth(int k) {
    int x = root, _k = k;
    while(x) {
        if(_k > s(l(x)) && _k <= s(l(x)) + c(x))
            return v(x);
        if(_k <= s(l(x)))
            x = l(x);
        else
            _k -= s(l(x)) + c(x), x = r(x);
    }
    return 0;
}

inline int QueryPre(int v) {
    int x = root, res = -INF;
    while(x) {
        if(v(x) < v)
            res = v(x), x = r(x);
        else
            x = l(x);
    }
    return res;
}

inline int QuerySuf(int v) {
    int x = root, res = INF;
    while(x) {
        if(v(x) > v)
            res = v(x), x = l(x);
        else
            x = r(x);
    }
    return res;
}

inline int QueryRank(int v) {
    int x = root, rank = 0;
    while(x) {
        if(v(x) == v)
            return rank + s(l(x)) + 1;
        if(v < v(x))
            x = l(x);
        else
            rank += s(l(x)) + c(x), x = r(x);
    }
    return rank;
}

int main() {
    srand(time(NULL)); // attention!!
    n = read();
    while(n--) {
        opt = read(), o = read();
        switch(opt) {
            case 1: {
                Insert(root, o);
                break;
            }
            case 2: {
                Delete(root, o);
                break;
            }
            case 3: {
                printf("%d\n", QueryRank(o));
                break;
            }
            case 4: {
                printf("%d\n", QueryKth(o));
                break;
            }
            case 5: {
                printf("%d\n", QueryPre(o));
                break;
            }
            default: {
                printf("%d\n", QuerySuf(o));
                break;
            }
        }
    }
    return 0;
}

written at 2019/03/08/23:51