To boldly go where no one has gone before.

【学习笔记】替罪羊树

2019-03-17 15:22:49


顾名思义,替罪羊树的思想就是:

如果你儿子建得不好,你就背锅。

所以照这个思维,我们可以定义一个$\alpha \in(0.5,1)$作为一个阈值,一般就$0.7$的样子。也就是说如果当前节点子树大小$s(x)>s(father) \times \alpha$,那么就把这颗树压扁,然后二分重建。

删除操作也属于惰性删除,也就是维护一个当前子树没有被删除的节点的总个数$d(x)$,这样就不用每次都跑去删除然后压扁重建了。

合并信息$(Update)$

注意要把$s(x)$和$d(x)$都更新了,否则惰性操作就会懒出毛病。

inline void Update(int x) {
    s(x) = s(l(x)) + s(r(x)) + c(x);
    d(x) = d(l(x)) + d(r(x)) + c(x);
    return;
}

压扁

判断是否需要被压扁

对于一个节点$x$,有两种情况下它需要背锅被扁:

  1. 左右子树中有一个子树的节点个数大于$s(x)\times \alpha$
  2. 删得太多,惰性操作导致树失衡。

(当然,如果根本就没有这个点,那还是不扁算了)

压扁

压扁的时候,直接按照中序遍历的顺序把子树展开,丢到内存池(数组/$vector$实现)里头去就可以了。

重建

重建的时候,直接二分就行。这里记录的是对$[l,r)$ 建树的鬼畜操作方法。当然也可以改成对$[l,r]$建树,一模一样的。

inline bool NeedRbu(int x) {
    return c(x) && (alpha * s(x) <= (double)max(s(l(x)), s(r(x))) || d(x) <= (double)alpha * s(x));
}

inline void Rbu_Flatten(int &p, int x) {
    if(!x)
        return;
    Rbu_Flatten(p, l(x));
    if(d(x))
        cache[p++] = x;
    Rbu_Flatten(p, r(x));
    return;
}

inline int Rbu_Build(int l, int r) {
    int mid = (l + r) >> 1;
    if(l >= r)
        return 0;
    l(cache[mid]) = Rbu_Build(l, mid);
    r(cache[mid]) = Rbu_Build(mid+1, r);
    Update(cache[mid]);
    return cache[mid];
}

inline void Rbu(int &x) {
    int p = 0;
    Rbu_Flatten(p, x);
    x = Rbu_Build(0, p);
}

插入

和$Treap$一样,不过在更新节点信息之后要看一下需不需要拍扁

inline void Insert(int &x, int v) {
    if(!x) {
        x = ++cnt;
        if(!root)
            root = x;
        v(x) = v;
        l(x) = r(x) = 0;
        c(x) = s(x) = d(x) = 1;
        return;
    }
    if(v(x) == v)
        c(x)++;
    else
        if(v <= v(x))
            Insert(l(x), v);
        else
            Insert(r(x), v);
    Update(x);
    if(NeedRbu(x))
        Rbu(x);
    return;
}

删除

也和$Treap$差不多,惰性操作即可。一样的需要看一下需不需要拍扁。

inline void Delete(int &x, int v) {
    if(!x)
        return;
    d(x)--;
    if(v(x) == v) {
        if(c(x)) 
            c(x)--;
    }   
    else {
        if(v <= v(x))
            Delete(l(x), v);
        else
            Delete(r(x), v);
        Update(x);
    }
    if(NeedRbu(x))
        Rbu(x);
    return;
}

查找第$k$大

和$Treap$基本上一模一样,先查看$(d(l(x)),d(l(x))+c(x)]$区间,然后再往左右子树找即可。

inline int QueryKth(int x, int k) {
    if(!k)
        return 0;
    if(d(l(x)) < k && k <= d(l(x)) + c(x))
        return v(x);
    if(k <= d(l(x)) + c(x))
        return QueryKth(l(x), k);
    else
        return QueryKth(r(x), k-d(l(x))-c(x));
}

查询排名、前缀、后缀

这里有个更高级的方法,就是用$UpperBound()$和$UpperGreater()$。

$UpperBound()$

$UpperBound()$返回的是严格大于$v$的最小元素的最小名次,查询思路简单粗暴。注意:如果根本就没有这个节点,那么当前询问的值就在父节点占据的最后一个位置$+1$处,因此返回$1$即可。

inline int UpperBound(int x, int v) {
    if(!x)
        return 1;
    if(v(x) == v && c(x))
        return d(l(x)) + c(x) + 1;
    else
        if(v < v(x))
            return UpperBound(l(x), v);
        else
            return d(l(x)) + c(x) + UpperBound(r(x), v);
}

$UpperGreater()$

和$UpperBound()$是反义函数,返回严格小于$v$的最大元素的最大名次。写的时候注意:访问左右子树时,判断的条件不变,但是执行的语句反序即可。由于是反义函数,所以在没有这个结点的时候返回$0$即可。

inline int UpperGreater(int x, int v) {
    if(!x)
        return 0;
    if(v(x) == v)
        return d(l(x));
    else
        if(v(x) < v)
            return d(l(x)) + c(x) + UpperGreater(r(x), v);
        else
            return UpperGreater(l(x), v);
}

查询排名

这个就不用单独写一个$QueryRank()$函数了,对于需要查询排名的值$v$,根据对$UpperGreater()$的定义,直接返回$UpperGreater(root,v)+1$即可

printf("%d\n", UpperGreater(root, v) + 1);

查询前缀、后继

就更方便了,由于已经有了$QueryKth()$函数、$UpperBound()$函数以及$UpperGreater()$函数,因此根本不需要再打新函数,一切都简单了。

查找$v$的前缀时,直接

printf("%d\n", QueryKth(root, UpperGreater(root, v)));

查找$v$的后继时,直接

printf("%d\n", QueryKth(root, UpperBound(root, o)));

当然,这是在模板对于前缀和后继的定义下得出的操作。具体情况具体分析。


完整代码

#include<bits/stdc++.h>
#define l(x) tree[x].lson
#define r(x) tree[x].rson
#define v(x) tree[x].val
#define c(x) tree[x].cnt
#define s(x) tree[x].size
#define d(x) tree[x].deleted
using namespace std;

const int MAXN = 1000010;
const int INF = 1 << 30;
const double alpha = 0.7;

struct ScapeGoat {
    int lson, rson, val, cnt, size, deleted;
} tree[MAXN];
int root, cnt = 0;
int cache[MAXN];

inline int read() {
    int res = 0, uz = 1;
    char ch = getchar();
    while(ch < '0' || ch > '9') {
        if(ch == '-')
            uz = -1;
        ch = getchar();
    }
    while(ch >= '0' && ch <= '9') {
        res = (res << 3) + (res << 1) + (ch ^ '0');
        ch = getchar();
    }
    return res * uz;
}

inline void Update(int x) {
    s(x) = s(l(x)) + s(r(x)) + c(x);
    d(x) = d(l(x)) + d(r(x)) + c(x);
    return;
}

inline bool NeedRbu(int x) {
    return c(x) && (alpha * s(x) <= (double)max(s(l(x)), s(r(x))) || d(x) <= (double)alpha * s(x));
}

inline void Rbu_Flatten(int &p, int x) {
    if(!x)
        return;
    Rbu_Flatten(p, l(x));
    if(d(x))
        cache[p++] = x;
    Rbu_Flatten(p, r(x));
    return;
}

inline int Rbu_Build(int l, int r) {
    int mid = (l + r) >> 1;
    if(l >= r)
        return 0;
    l(cache[mid]) = Rbu_Build(l, mid);
    r(cache[mid]) = Rbu_Build(mid+1, r);
    Update(cache[mid]);
    return cache[mid];
}

inline void Rbu(int &x) {
    int p = 0;
    Rbu_Flatten(p, x);
    x = Rbu_Build(0, p);
}

//operations

inline void Insert(int &x, int v) {
    if(!x) {
        x = ++cnt;
        if(!root)
            root = x;
        v(x) = v;
        l(x) = r(x) = 0;
        c(x) = s(x) = d(x) = 1;
        return;
    }
    if(v(x) == v)
        c(x)++;
    else
        if(v <= v(x))
            Insert(l(x), v);
        else
            Insert(r(x), v);
    Update(x);
    if(NeedRbu(x))
        Rbu(x);
    return;
}

inline void Delete(int &x, int v) {
    if(!x)
        return;
    d(x)--;
    if(v(x) == v) {
        if(c(x)) 
            c(x)--;
    }   
    else {
        if(v <= v(x))
            Delete(l(x), v);
        else
            Delete(r(x), v);
        Update(x);
    }
    if(NeedRbu(x))
        Rbu(x);
    return;
}

inline int QueryKth(int x, int k) {
    if(!k)
        return 0;
    if(d(l(x)) < k && k <= d(l(x)) + c(x))
        return v(x);
    if(k <= d(l(x)) + c(x))
        return QueryKth(l(x), k);
    else
        return QueryKth(r(x), k-d(l(x))-c(x));
}

inline int UpperBound(int x, int v) {
    if(!x)
        return 1;
    if(v(x) == v && c(x))
        return d(l(x)) + c(x) + 1;
    else
        if(v < v(x))
            return UpperBound(l(x), v);
        else
            return d(l(x)) + c(x) + UpperBound(r(x), v);
}

inline int UpperGreater(int x, int v) {
    if(!x)
        return 0;
    if(v(x) == v)
        return d(l(x));
    else
        if(v(x) < v)
            return d(l(x)) + c(x) + UpperGreater(r(x), v);
        else
            return UpperGreater(l(x), v);
}

int main() {
    int n, opt, o;
    n = read();
    while(n--) {
        opt = read(), o = read();
        switch(opt) {
            case 1: {
                Insert(root, o);
                break;
            }
            case 2: {
                Delete(root, o);
                break;
            }
            case 3: {
                printf("%d\n", UpperGreater(root, o) + 1);
                break;
            }
            case 4: {
                printf("%d\n", QueryKth(root, o));
                break;
            }
            case 5: {
                printf("%d\n", QueryKth(root, UpperGreater(root, o)));
                break;
            }
            default: {
                printf("%d\n", QueryKth(root, UpperBound(root, o)));
                break;
            }
        }
    }
    return 0;
}