To boldly go where no one has gone before.

【学习笔记】树状数组

2019-08-08 18:58:45


定义

树状数组就是为前缀数组建立的树形结构,其最朴素的应用是单点修改,区间查询

对于树状数组中的每一个节点$c[i]$,其管辖$[i-lowbit(i)+1,\ i]$这一段区间,其中$lowbit(i)$计算的是$i$的二进制数位中最靠右的$1$所表示的数。例如$6_{(10)}=110_{(2)}$,那么$c[6]$管辖的即是$[5,6]$这个区间。

可以证明一个数$n$的二进制数位中最多只有$log(n)$个$1$,因此树状数组的复杂度在最坏情况下为$O(logn)$,而最优情况下为$O(1)$。由于树状数组的本质是一个特殊的前缀数组,因此空间开销为$n$,这比树状数组的稳定$O(logn)$复杂度、空间开销$4n$都要优秀。同时,树状数组的代码量远远小于线段树,因此在其应用范围内,树状数组不失为一个优秀的选择。

改良

由于树状数组特性,其原生只支持单点修改、区间查询;而面临区间修改、区间查询的情况,树状数组似乎失去了作用。实际上,树状数组也能够进行区间修改、区间查询。

考虑一个差分数组$d[]$,可以知道$a[n]=\sum_{i=1}^n d[i]$。那么:$$s[n]=\sum_{i=1}^n a[i]=\sum_{i=1}^n \sum_{j=1}^i d[j]$$$$=n\cdot d[1]+(n-1)\cdot d[2]+\cdots+2\cdot d[n-1] + 1\cdot d[n].$$

对这个式子进行处理,可以得到:$$\text{原式}=n\cdot(d[1]+d[2]+\cdots+d[n])-(0\cdot d[1]+1\cdot d[2]+\cdots+(n-1)\cdot d[n]).$$ 也就是说,$s[n]$可以拆成两个部分,一个是$n\cdot\sum_{i=1}^nd[i]$,另一个是$-\sum_{i=1}^n(i-1)\cdot d[i]$,这两个部分就可以用两个树状数组分别维护,每次区间修改都对两棵树进行修改,每次区间查询就进行一次运算即可。


完整代码: (P3372)

#include <bits/stdc++.h>
#define lowbit(x) (x & (-x))
using namespace std;
typedef long long ll;
const int MAXN = 1e5 + 10;
int n, m;
ll c1[MAXN], c2[MAXN], a[MAXN];

void add(ll *c, int x, int v)
{
    while(x <= n)
    {
        c[x] += v;
        x += lowbit(x);
    }
}

ll query(ll *c, int x)
{
    ll res = 0;
    while(x)
    {
        res += c[x];
        x -= lowbit(x);
    }
    return res;
}

void edit(int l, int r, ll k)
{
    add(c1, l, k);
    add(c1, r + 1, - k);
    add(c2, l, k * (l - 1));
    add(c2, r + 1, - k * r);
}

ll presum(int x)
{
    return x * query(c1, x) - query(c2, x);
}

ll sum(int l, int r)
{
    return presum(r) - presum(l - 1);
}

void init()
{
    memset(c1, 0, sizeof(c1));
    memset(c2, 0, sizeof(c2));
    for(int i = 1; i <= n; i++)
    {
        add(c1, i, a[i] - a[i - 1]);
        add(c2, i, (i - 1) * (a[i] - a[i - 1]));
    }
}

int main()
{
    scanf("%d%d", &n, &m);
    for(int i = 1; i <= n; i++)
        scanf("%d", &a[i]);
    init();
    int opt, x, y;
    ll k;
    while(m--)
    {
        scanf("%d%d%d", &opt, &x, &y);
        switch(opt)
        {
            case 1:
                scanf("%lli", &k);
                edit(x, y, k);
                break;
            case 2:
                printf("%lli\n", sum(x, y));
                break;
        }
    }
    return 0;
}

二次改良

例题 LOJ 10115

考虑把每次区间加的操作抽象为一对括号$()$,那么每次询问$[l,r]$区间有多少种树时,答案就可以转化为$[1,r]$区间内的左括号数减去$[1,l)$区间内的右括号数,因此直接使用两个树状数组维护左右括号数即可。

#include <bits/stdc++.h>
using namespace std;
const int MAXN = 5e4 + 10;
int c1[MAXN], c2[MAXN];
int n, m;

void add(int *c, int x, int v) {
    for ( ; x <= n; x += x & (-x))
        c[x] += v;
    return ;
}

int query(int *c, int x) {
    int res = 0;
    for ( ; x; x -= x & (-x))
        res += c[x];
    return res;
}

int main() {
    scanf("%d %d", &n, &m);
    for (int k, l, r; m; m--) {
        scanf("%d %d %d", &k, &l, &r);
        if (k == 1) {
            add(c1, l, 1);
            add(c2, r, 1);
        } else
            printf("%d\n", query(c1, r) - query(c2, l - 1));
    }
    return 0;
}

例题 POJ 1990

分析

这道题乍一看是个$O(n^2)$,然而很明显$20000$的数据范围限定了复杂度只能是$O(nlogn)$。如果扫一遍所有奶牛的复杂度是$O(n)$,那么就必须在$O(logn)$时间内完成对一头牛的计算。考虑如何用树状数组做这道题:

首先,一头奶牛要对答案产生贡献,其必须与$v$小于自身的奶牛交流。这也就意味着,如果以$v$为关键字对原序列进行升序排序,那么在$v$的角度就转化为了一个单调性问题:每头奶牛与其之前的奶牛交流就会对答案产生贡献。问题在于如何处理“距离$\times$最大阈值$=$贡献”这个恶心的算式。

由于已经转化为一个单调性问题,所以不用再枚举每头奶牛,当前奶牛$i$能产生的贡献即是$i\times v[i]\times sum(\left|x[i]-x[j]\right|)\ (j<i)$。拆掉绝对值,就把贡献砍成两部分:左边的奶牛和右边的奶牛。

所以这里维护两个树状数组,$c1[]$维护坐标,$c2[]$维护个数;每次查询完左边信息以后,再利用左边信息获得右边信息,最后将当前奶牛加入树状数组当中。具体操作可以看代码。

#include <iostream>
#include <stdio.h>
#include <algorithm>
#include <cstring>
using namespace std;
typedef long long ll;
const int MAXN = 20010;
struct type_cow
{
    int x, v;
    bool operator < (const type_cow &rhs) const
    {
        return v < rhs.v;
    }
} cow[MAXN];
int n, max_x = 0;
ll c1[MAXN], c2[MAXN];
//  c1维护坐标,c2维护个数 

int lowbit(int x)
{
    return x & (-x);
}

void add(ll *c, int pos, ll v)
{
    while(pos <= max_x)
    {
        c[pos] += v;
        pos += lowbit(pos);
    }
}

ll query(ll *c, int pos)
{
    ll res = 0;
    while(pos)
    {
        res += c[pos];
        pos -= lowbit(pos);
    }
    return res;
}

int main()
{
    memset(c1, 0, sizeof(c1));
    memset(c2, 0, sizeof(c2));
    scanf("%d", &n);
    for(int i = 1; i <= n; i++)
    {
        scanf("%d%d", &cow[i].v, &cow[i].x);
        max_x = max(max_x, cow[i].x);
    }
    sort(cow + 1, cow + n + 1);
    ll ans = 0, dis, num;
    for(int i = 1; i <= n; i++)
    {
//      left
        dis = query(c1, cow[i].x);
        num = query(c2, cow[i].x);
        ans += (num * cow[i].x - dis) * cow[i].v;
//      right
        dis = query(c1, max_x) - dis;
        num = (i - 1) - num;
        ans += (dis - num * cow[i].x) * cow[i].v;
        add(c1, cow[i].x, cow[i].x);
        add(c2, cow[i].x, 1LL);
    }
    printf("%lli", ans);
    return 0;
}