【学习笔记】Treap 平衡树

NCC79601

2019-03-09 01:43:20

Solution

没看到写得很透彻的$Treap$题解,作为初学者还是自己写一篇吧。 # $Treap$ 定义:如果一棵二叉树$T$满足$\forall x\in T$,$x>$ $x$左子树中的所有节点,且$x<$ $x$右子树中的所有节点,则$T$可被称为二叉查找树$BST$ $(Binary$ $Search$ $Tree)$。 顾名思义,$Treap=Tree+Heap$,所以$Treap$平衡树其实质上就是一个把堆套在树上的结构。如果我们需要对一个序列进行插入、删除、求前驱后继、求第$k$大数、求排名等操作,如果只是简单地建立一棵$BST$来跑的话,**复杂度和树的高度有关** 。其最坏复杂度是$O(n)$,例如输入数据有序的时候,树直接就变成一条链了。所以我们希望可以在维护$BST$性质的情况下,尽可能地把树的高度压缩,这样就可以降低时间复杂度了。 而对于一棵$BST$,有两个操作:右旋$(Zig)$和左旋$(Zag)$,它们不会改变$BST$性质,却可以改变节点间的父子关系。 # $Zig$操作 ![zig.jpg](https://img.ffis.me/images/2019/03/09/zig.jpg) 具体步骤为: 1. 把蓝色节点接到红色节点的左儿子上面。由于**蓝色节点一定小于红色节点而大于绿色节点**,因此可以把它接到红色节点的左儿子上面。 2. 把绿色点的右儿子替换为红色节点。 3. 把红色节点维护的节点数信息更新到绿色节点上。 4. 将$Zig$之后的红色节点维护的节点数信息更新,然后把红色节点换掉。 * 注意,更新节点数信息的时候**一定不要**更新成$Zig$后变成树根的那个点,而更新的是$Zig$后**变成树根右儿子的那个点**。这里被坑了。 ```cpp inline void Zig(int &x) { int k = l(x); l(x) = r(k), r(k) = x; s(k) = s(x); update(x); // attention!! x = k; return; } ``` # $Zag$操作 ![zag.jpg](https://img.ffis.me/images/2019/03/09/zag.jpg) $Zag$与$Zig$操作无异,只是$l$与$r$互换。$Zig$和$Zag$互为逆操作。 ```cpp inline void Zag(int &x) { int k = r(x); r(x) = l(k), l(k) = x; s(k) = s(x); update(x); // attention!! x = k; return; } ``` # 维护方式 我们与$BST$建立一个完全对应的小根堆,每次新建节点的时候就随机为节点分配一个优先级,并且用$Zig$和$Zag$维护小根堆。由于堆的性质,其高度为$log(n)$,因此用这种方法就可以在维持$BST$性质的前提下把$BST$的高度压缩到$log(n)$,因此接下来所有的操作最坏复杂度就全部变成$log(n)$级别了。 # 插入节点 插入结点的时候有两种情况: 1. 当前插入位置是空点,则直接在这里新建一个节点,并且初始化节点属性。 2. 当前插入位置不是空点,先更新当前节点维护的节点数。若插入值和当前节点值相等,直接把$c(x)+1$;若插入值小于当前节点值,则在左子树中插入,并且维护堆属性:若左儿子的优先级小于当前节点的优先级,就$Zig$一下;同样地,若插入值大于当前节点值,则在右子树中插入,并且维护堆属性:若右儿子的优先级小于当前节点的优先级,就$Zag$一下。 这样就完成了对节点的插入。 ```cpp inline void Insert(int &x, int v) { if(!x) { x = ++id; v(x) = v, c(x) = s(x) = 1; l(x) = r(x) = 0; p(x) = rand(); return; } s(x)++; if(v(x) == v) c(x)++; else { if(v < v(x)) { Insert(l(x), v); if(p(l(x)) < p(x)) Zig(x); } else { Insert(r(x), v); if(p(r(x)) < p(x)) Zag(x); } } return; } ``` # 删除节点 删除节点的时候有两种大情况: 1. 当前节点值等于删除节点值,若当前节点计数$\ge1$则直接$c(x)-1,s(x)-1$即可;若当前节点计数为$1$,则:①如果该点为链节点$($只有一个儿子$)$,直接把它替换为儿子即可;②如果该点有左右儿子,那么就类似于堆的删除,若左儿子优先级小于右儿子则$Zig$,再继续删除;否则$Zag$后再继续删除。 2. 当前节点不等于删除节点值,那么**先更新**$s(x)$,若删除节点值小于当前节点值则从左节点继续开始删除;若删除节点值大于当前节点值则从右节点继续开始删除。 ```cpp inline void Delete(int &x, int v) { if(v(x) == v) { if(c(x) > 1) c(x)--, s(x)--; else if(!l(x) || !r(x)) x = l(x) + r(x); // x = l(x) ? l(x) : r(x); else if(p(l(x)) < p(r(x))) Zig(x), Delete(x, v); else Zag(x), Delete(x, v); return; } s(x)--; // attention!! if(v < v(x)) Delete(l(x), v); else Delete(r(x), v); return; } ``` # 查找第$k$大数 直接从根开始查找。对于查询的序数$k$,可知:如果$k\in(s(l(x)), s(l(x))+c(x)]$,则比$k$小的就只有$s(l(x))$个,因此$v(x)$就是答案。从树根开始查找,只要$x\neq0$,则: 1. 若$k\in(s(l(x)), s(l(x))+c(x)]$,则直接返回$v(x)$。 2. 若$k\notin(s(l(x)), s(l(x))+c(x)]$,①$k\le s(l(x))$,则查询的点在左子树内,因此把$x$更新为$l(x)$,再继续查找;②$k>s(l(x))$,则查询的点在左子树内,因此把$k$更新为$k-s(l(x))-c(x)$,把$x$更新为$r(x)$,即在右子树里查找第$k-s(l(x))-c(x)$大节点$($等价于更新之前的查找$)$,然后继续查找。 ```cpp inline int QueryKth(int k) { int x = root, _k = k; while(x) { if(_k > s(l(x)) && _k <= s(l(x)) + c(x)) return v(x); if(_k <= s(l(x))) x = l(x); else _k -= s(l(x)) + c(x), x = r(x); } return 0; } ``` # 查找前驱与后继 这里定义:$x$的前驱指序列中小于$x$的最大数;$x$的后继指序列中大于$x$的最小数。$($有些地方的定义和这个不一样$)$ 求前驱与求后继的思路相似,这里以前驱为例: 用$res$维护答案。从根节点开始搜索,只要$x\neq0$,则: 1. 如果当前节点值小于查找值,则把$res$更新为当前节点值,然后把$x$更新为$r(x)$以寻找右子树中是否还有更优的解。 2. 如果当前节点值大于查找值,则把$x$更新为$l(x)$以寻找左子树中是否有小于查找值的节点。后继的算法与前驱相似,但略有区别,不再赘述。 ```cpp inline int QueryPre(int v) { int x = root, res = -INF; while(x) { if(v(x) < v) res = v(x), x = r(x); else x = l(x); } return res; } inline int QuerySuf(int v) { int x = root, res = INF; while(x) { if(v(x) > v) res = v(x), x = l(x); else x = r(x); } return res; } ``` # 获取排名 最后就是利用$Treap$求查找值在序列中的排名。维护一个当前累计排名$rank$,从根节点开始,只要$x\neq0$,则: 1. 如果当前节点值等于查找值,直接返回$rank+s(l(x))+1$。 2. 如果当前节点值小于查找值,则把$x$更新为$l(x)$;如果当前节点值大于查找值,先把$rank$更新为$rank+s(l(x))+c(x)$,表示当前已有$rank+s(l(x))+c(x)$个节点小于查找值。 最后返回$rank$即可。 ```cpp inline int QueryRank(int v) { int x = root, rank = 0; while(x) { if(v(x) == v) return rank + s(l(x)) + 1; if(v < v(x)) x = l(x); else rank += s(l(x)) + c(x), x = r(x); } return rank; } ``` --- 完整代码: ```cpp //code #include<bits/stdc++.h> #include<ctime> #define l(x) tree[x].lson #define r(x) tree[x].rson #define v(x) tree[x].val #define p(x) tree[x].pri #define c(x) tree[x].cnt #define s(x) tree[x].size using namespace std; const int MAXN = 100010; const int INF = 1 << 30; struct TREE { int lson, rson, val, pri, cnt, size; } tree[MAXN]; int id = 0, root = 0; int n, o, opt; inline int read() { int res = 0, uz = 1; char ch = getchar(); while(ch < '0' || ch > '9') { if(ch == '-') uz = -1; ch = getchar(); } while(ch >= '0' && ch <= '9') { res = (res << 3) + (res << 1) + (ch ^ '0'); ch = getchar(); } return res * uz; } inline void update(int &x) { s(x) = s(l(x)) + s(r(x)) + c(x); return; } inline void Zig(int &x) { int k = l(x); l(x) = r(k), r(k) = x; s(k) = s(x); update(x); // attention!! x = k; return; } inline void Zag(int &x) { int k = r(x); r(x) = l(k), l(k) = x; s(k) = s(x); update(x); // attention!! x = k; return; } inline void Insert(int &x, int v) { if(!x) { x = ++id; v(x) = v, c(x) = s(x) = 1; l(x) = r(x) = 0; p(x) = rand(); return; } s(x)++; if(v(x) == v) c(x)++; else { if(v < v(x)) { Insert(l(x), v); if(p(l(x)) < p(x)) Zig(x); } else { Insert(r(x), v); if(p(r(x)) < p(x)) Zag(x); } } return; } inline void Delete(int &x, int v) { if(v(x) == v) { if(c(x) > 1) c(x)--, s(x)--; else if(!l(x) || !r(x)) x = l(x) + r(x); else if(p(l(x)) < p(r(x))) Zig(x), Delete(x, v); else Zag(x), Delete(x, v); return; } s(x)--; // attention!! if(v < v(x)) Delete(l(x), v); else Delete(r(x), v); return; } inline int QueryKth(int k) { int x = root, _k = k; while(x) { if(_k > s(l(x)) && _k <= s(l(x)) + c(x)) return v(x); if(_k <= s(l(x))) x = l(x); else _k -= s(l(x)) + c(x), x = r(x); } return 0; } inline int QueryPre(int v) { int x = root, res = -INF; while(x) { if(v(x) < v) res = v(x), x = r(x); else x = l(x); } return res; } inline int QuerySuf(int v) { int x = root, res = INF; while(x) { if(v(x) > v) res = v(x), x = l(x); else x = r(x); } return res; } inline int QueryRank(int v) { int x = root, rank = 0; while(x) { if(v(x) == v) return rank + s(l(x)) + 1; if(v < v(x)) x = l(x); else rank += s(l(x)) + c(x), x = r(x); } return rank; } int main() { srand(time(NULL)); // attention!! n = read(); while(n--) { opt = read(), o = read(); switch(opt) { case 1: { Insert(root, o); break; } case 2: { Delete(root, o); break; } case 3: { printf("%d\n", QueryRank(o)); break; } case 4: { printf("%d\n", QueryKth(o)); break; } case 5: { printf("%d\n", QueryPre(o)); break; } default: { printf("%d\n", QuerySuf(o)); break; } } } return 0; } ``` written at 2019/03/08/23:51