【学习笔记】Treap 平衡树
NCC79601
2019-03-09 01:43:20
没看到写得很透彻的$Treap$题解,作为初学者还是自己写一篇吧。
# $Treap$
定义:如果一棵二叉树$T$满足$\forall x\in T$,$x>$ $x$左子树中的所有节点,且$x<$ $x$右子树中的所有节点,则$T$可被称为二叉查找树$BST$ $(Binary$ $Search$ $Tree)$。
顾名思义,$Treap=Tree+Heap$,所以$Treap$平衡树其实质上就是一个把堆套在树上的结构。如果我们需要对一个序列进行插入、删除、求前驱后继、求第$k$大数、求排名等操作,如果只是简单地建立一棵$BST$来跑的话,**复杂度和树的高度有关** 。其最坏复杂度是$O(n)$,例如输入数据有序的时候,树直接就变成一条链了。所以我们希望可以在维护$BST$性质的情况下,尽可能地把树的高度压缩,这样就可以降低时间复杂度了。
而对于一棵$BST$,有两个操作:右旋$(Zig)$和左旋$(Zag)$,它们不会改变$BST$性质,却可以改变节点间的父子关系。
# $Zig$操作
![zig.jpg](https://img.ffis.me/images/2019/03/09/zig.jpg)
具体步骤为:
1. 把蓝色节点接到红色节点的左儿子上面。由于**蓝色节点一定小于红色节点而大于绿色节点**,因此可以把它接到红色节点的左儿子上面。
2. 把绿色点的右儿子替换为红色节点。
3. 把红色节点维护的节点数信息更新到绿色节点上。
4. 将$Zig$之后的红色节点维护的节点数信息更新,然后把红色节点换掉。
* 注意,更新节点数信息的时候**一定不要**更新成$Zig$后变成树根的那个点,而更新的是$Zig$后**变成树根右儿子的那个点**。这里被坑了。
```cpp
inline void Zig(int &x) {
int k = l(x);
l(x) = r(k), r(k) = x;
s(k) = s(x);
update(x); // attention!!
x = k;
return;
}
```
# $Zag$操作
![zag.jpg](https://img.ffis.me/images/2019/03/09/zag.jpg)
$Zag$与$Zig$操作无异,只是$l$与$r$互换。$Zig$和$Zag$互为逆操作。
```cpp
inline void Zag(int &x) {
int k = r(x);
r(x) = l(k), l(k) = x;
s(k) = s(x);
update(x); // attention!!
x = k;
return;
}
```
# 维护方式
我们与$BST$建立一个完全对应的小根堆,每次新建节点的时候就随机为节点分配一个优先级,并且用$Zig$和$Zag$维护小根堆。由于堆的性质,其高度为$log(n)$,因此用这种方法就可以在维持$BST$性质的前提下把$BST$的高度压缩到$log(n)$,因此接下来所有的操作最坏复杂度就全部变成$log(n)$级别了。
# 插入节点
插入结点的时候有两种情况:
1. 当前插入位置是空点,则直接在这里新建一个节点,并且初始化节点属性。
2. 当前插入位置不是空点,先更新当前节点维护的节点数。若插入值和当前节点值相等,直接把$c(x)+1$;若插入值小于当前节点值,则在左子树中插入,并且维护堆属性:若左儿子的优先级小于当前节点的优先级,就$Zig$一下;同样地,若插入值大于当前节点值,则在右子树中插入,并且维护堆属性:若右儿子的优先级小于当前节点的优先级,就$Zag$一下。
这样就完成了对节点的插入。
```cpp
inline void Insert(int &x, int v) {
if(!x) {
x = ++id;
v(x) = v, c(x) = s(x) = 1;
l(x) = r(x) = 0;
p(x) = rand();
return;
}
s(x)++;
if(v(x) == v)
c(x)++;
else {
if(v < v(x)) {
Insert(l(x), v);
if(p(l(x)) < p(x))
Zig(x);
} else {
Insert(r(x), v);
if(p(r(x)) < p(x))
Zag(x);
}
}
return;
}
```
# 删除节点
删除节点的时候有两种大情况:
1. 当前节点值等于删除节点值,若当前节点计数$\ge1$则直接$c(x)-1,s(x)-1$即可;若当前节点计数为$1$,则:①如果该点为链节点$($只有一个儿子$)$,直接把它替换为儿子即可;②如果该点有左右儿子,那么就类似于堆的删除,若左儿子优先级小于右儿子则$Zig$,再继续删除;否则$Zag$后再继续删除。
2. 当前节点不等于删除节点值,那么**先更新**$s(x)$,若删除节点值小于当前节点值则从左节点继续开始删除;若删除节点值大于当前节点值则从右节点继续开始删除。
```cpp
inline void Delete(int &x, int v) {
if(v(x) == v) {
if(c(x) > 1)
c(x)--, s(x)--;
else
if(!l(x) || !r(x))
x = l(x) + r(x);
// x = l(x) ? l(x) : r(x);
else
if(p(l(x)) < p(r(x)))
Zig(x), Delete(x, v);
else
Zag(x), Delete(x, v);
return;
}
s(x)--; // attention!!
if(v < v(x))
Delete(l(x), v);
else
Delete(r(x), v);
return;
}
```
# 查找第$k$大数
直接从根开始查找。对于查询的序数$k$,可知:如果$k\in(s(l(x)), s(l(x))+c(x)]$,则比$k$小的就只有$s(l(x))$个,因此$v(x)$就是答案。从树根开始查找,只要$x\neq0$,则:
1. 若$k\in(s(l(x)), s(l(x))+c(x)]$,则直接返回$v(x)$。
2. 若$k\notin(s(l(x)), s(l(x))+c(x)]$,①$k\le s(l(x))$,则查询的点在左子树内,因此把$x$更新为$l(x)$,再继续查找;②$k>s(l(x))$,则查询的点在左子树内,因此把$k$更新为$k-s(l(x))-c(x)$,把$x$更新为$r(x)$,即在右子树里查找第$k-s(l(x))-c(x)$大节点$($等价于更新之前的查找$)$,然后继续查找。
```cpp
inline int QueryKth(int k) {
int x = root, _k = k;
while(x) {
if(_k > s(l(x)) && _k <= s(l(x)) + c(x))
return v(x);
if(_k <= s(l(x)))
x = l(x);
else
_k -= s(l(x)) + c(x), x = r(x);
}
return 0;
}
```
# 查找前驱与后继
这里定义:$x$的前驱指序列中小于$x$的最大数;$x$的后继指序列中大于$x$的最小数。$($有些地方的定义和这个不一样$)$
求前驱与求后继的思路相似,这里以前驱为例:
用$res$维护答案。从根节点开始搜索,只要$x\neq0$,则:
1. 如果当前节点值小于查找值,则把$res$更新为当前节点值,然后把$x$更新为$r(x)$以寻找右子树中是否还有更优的解。
2. 如果当前节点值大于查找值,则把$x$更新为$l(x)$以寻找左子树中是否有小于查找值的节点。后继的算法与前驱相似,但略有区别,不再赘述。
```cpp
inline int QueryPre(int v) {
int x = root, res = -INF;
while(x) {
if(v(x) < v)
res = v(x), x = r(x);
else
x = l(x);
}
return res;
}
inline int QuerySuf(int v) {
int x = root, res = INF;
while(x) {
if(v(x) > v)
res = v(x), x = l(x);
else
x = r(x);
}
return res;
}
```
# 获取排名
最后就是利用$Treap$求查找值在序列中的排名。维护一个当前累计排名$rank$,从根节点开始,只要$x\neq0$,则:
1. 如果当前节点值等于查找值,直接返回$rank+s(l(x))+1$。
2. 如果当前节点值小于查找值,则把$x$更新为$l(x)$;如果当前节点值大于查找值,先把$rank$更新为$rank+s(l(x))+c(x)$,表示当前已有$rank+s(l(x))+c(x)$个节点小于查找值。
最后返回$rank$即可。
```cpp
inline int QueryRank(int v) {
int x = root, rank = 0;
while(x) {
if(v(x) == v)
return rank + s(l(x)) + 1;
if(v < v(x))
x = l(x);
else
rank += s(l(x)) + c(x), x = r(x);
}
return rank;
}
```
---
完整代码:
```cpp
//code
#include<bits/stdc++.h>
#include<ctime>
#define l(x) tree[x].lson
#define r(x) tree[x].rson
#define v(x) tree[x].val
#define p(x) tree[x].pri
#define c(x) tree[x].cnt
#define s(x) tree[x].size
using namespace std;
const int MAXN = 100010;
const int INF = 1 << 30;
struct TREE {
int lson, rson, val, pri, cnt, size;
} tree[MAXN];
int id = 0, root = 0;
int n, o, opt;
inline int read() {
int res = 0, uz = 1;
char ch = getchar();
while(ch < '0' || ch > '9') {
if(ch == '-')
uz = -1;
ch = getchar();
}
while(ch >= '0' && ch <= '9') {
res = (res << 3) + (res << 1) + (ch ^ '0');
ch = getchar();
}
return res * uz;
}
inline void update(int &x) {
s(x) = s(l(x)) + s(r(x)) + c(x);
return;
}
inline void Zig(int &x) {
int k = l(x);
l(x) = r(k), r(k) = x;
s(k) = s(x);
update(x); // attention!!
x = k;
return;
}
inline void Zag(int &x) {
int k = r(x);
r(x) = l(k), l(k) = x;
s(k) = s(x);
update(x); // attention!!
x = k;
return;
}
inline void Insert(int &x, int v) {
if(!x) {
x = ++id;
v(x) = v, c(x) = s(x) = 1;
l(x) = r(x) = 0;
p(x) = rand();
return;
}
s(x)++;
if(v(x) == v)
c(x)++;
else {
if(v < v(x)) {
Insert(l(x), v);
if(p(l(x)) < p(x))
Zig(x);
} else {
Insert(r(x), v);
if(p(r(x)) < p(x))
Zag(x);
}
}
return;
}
inline void Delete(int &x, int v) {
if(v(x) == v) {
if(c(x) > 1)
c(x)--, s(x)--;
else
if(!l(x) || !r(x))
x = l(x) + r(x);
else
if(p(l(x)) < p(r(x)))
Zig(x), Delete(x, v);
else
Zag(x), Delete(x, v);
return;
}
s(x)--; // attention!!
if(v < v(x))
Delete(l(x), v);
else
Delete(r(x), v);
return;
}
inline int QueryKth(int k) {
int x = root, _k = k;
while(x) {
if(_k > s(l(x)) && _k <= s(l(x)) + c(x))
return v(x);
if(_k <= s(l(x)))
x = l(x);
else
_k -= s(l(x)) + c(x), x = r(x);
}
return 0;
}
inline int QueryPre(int v) {
int x = root, res = -INF;
while(x) {
if(v(x) < v)
res = v(x), x = r(x);
else
x = l(x);
}
return res;
}
inline int QuerySuf(int v) {
int x = root, res = INF;
while(x) {
if(v(x) > v)
res = v(x), x = l(x);
else
x = r(x);
}
return res;
}
inline int QueryRank(int v) {
int x = root, rank = 0;
while(x) {
if(v(x) == v)
return rank + s(l(x)) + 1;
if(v < v(x))
x = l(x);
else
rank += s(l(x)) + c(x), x = r(x);
}
return rank;
}
int main() {
srand(time(NULL)); // attention!!
n = read();
while(n--) {
opt = read(), o = read();
switch(opt) {
case 1: {
Insert(root, o);
break;
}
case 2: {
Delete(root, o);
break;
}
case 3: {
printf("%d\n", QueryRank(o));
break;
}
case 4: {
printf("%d\n", QueryKth(o));
break;
}
case 5: {
printf("%d\n", QueryPre(o));
break;
}
default: {
printf("%d\n", QuerySuf(o));
break;
}
}
}
return 0;
}
```
written at 2019/03/08/23:51