To boldly go where no one has gone before.

【学习笔记】LCA

2019-03-04 13:15:42


倍增LCA

预处理

先预处理出$n$以内所有数的$log(i)$的值,以进行常数优化。这里使用递推方法,即:$$log(i) = log(i-1) + (2^{log(i-1)} == i)$$ 意思即是说,如果$2^{log(i-1)}$与$i$相等,则$log(i)=log(i-1)+1$,否则$log(i)=log(i-1)$。注意,这样处理出来的结果是$\biggl\lfloor\log(i)\biggr\rfloor+1$,在调用的时候需要减$1$。

for(int i=1; i<=n; i++)
    lg[i] = lg[i-1] + ((1<<lg[i-1]) == i);

维护

对于每个点$i$维护深度$depth[i]$以及它的第$2^j$级父亲$f[i][j]$。 一个重要的递推式就是:$u$的第$2^i$级父亲等于$u$的第$2^{i-1}$级父亲的第$2^{i-1}$级父亲。

void dfs(int u, int fa) {
    depth[u] = depth[fa] + 1;
    f[u][0] = fa;
    for(int i=1; (1<<i)<=depth[u]; i++)
        f[u][i] = f[f[u][i-1]][i-1];
    for(int i=last[u]; i; i=edge[i].next) {
        if(edge[i].to == fa)
            continue;
        dfs(edge[i].to, u);
    }
    return;
}

核心算法

不妨设$f[u]>f[v]$,先将$u$跳到和$v$相同高度处。在跳的时候,从大向小跳,每次跳$\biggl\lfloor\log(depth[u]-depth[v])\biggr\rfloor$层,直至$u$与$v$来到同一高度。此后,存在两种情况: 1) $u$已经和$v$处在同一节点,即$v$本来为$u$的父节点,则直接返回$u$; 2) $u$和$v$处在不同的节点,那么就需要将$u$于$v$一同向上跳,跳的步数从$2^{log(i)}$一直到$1$,直到跳到它们$LCA$的下一层。此后$u$的父节点即为它们的$LCA$。

    if(x == y)
        return x;
    for(int i=lg[depth[x]]-1; i>=0; i--) 
        if(f[x][i] != f[y][i]) {
            x = f[x][i];
            y = f[y][i];
        }
    return f[x][0];

完整代码如下:

//code
#include<bits/stdc++.h>
using namespace std;

const int MAXN = 500010;

struct EDGE {
    int to, next;
} edge[2*MAXN];
int id = 0, last[MAXN];

void build_edge(int u, int v) {
    edge[++id].to = v;
    edge[id].next = last[u];
    last[u] = id; 
    return;
}

int f[MAXN][30], depth[MAXN];

void dfs(int u, int fa) {
    depth[u] = depth[fa] + 1;
    f[u][0] = fa;
    for(int i=1; (1<<i)<=depth[u]; i++)
        f[u][i] = f[f[u][i-1]][i-1];
    for(int i=last[u]; i; i=edge[i].next) {
        if(edge[i].to == fa)
            continue;
        dfs(edge[i].to, u);
    }
    return;
}

int lg[MAXN];
int n, m, s, a, b;

int lca(int u, int v) {
    int x = u, y = v;
    if(depth[x] < depth[y])
        swap(x, y);
    while(depth[x] > depth[y]) 
        x = f[x][lg[depth[x] - depth[y]] - 1];
    if(x == y)
        return x;
    for(int i=lg[depth[x]]-1; i>=0; i--) 
        if(f[x][i] != f[y][i]) {
            x = f[x][i];
            y = f[y][i];
        }
    return f[x][0];
}

int main() {
    scanf("%d%d%d", &n, &m, &s);
    for(int i=1; i<n; i++) {
        scanf("%d%d", &a, &b);
        build_edge(a, b);
        build_edge(b, a);
    }
    dfs(s, 0);
    for(int i=1; i<=n; i++)
        lg[i] = lg[i-1] + ((1<<lg[i-1]) == i);
    while(m--) {
        scanf("%d%d", &a, &b);
        printf("%d\n", lca(a, b));
    }
    return 0;
}

另一种更快的写法

那就是这个不用预处理$log$的写法了,好像跑得快,思想也很简单。

#include <bits/stdc++.h>
using namespace std;
const int MAXN = 5e5 + 10;
struct type_edge {
    int to, next;
} edge[MAXN << 1];
int head[MAXN], id = 0;
int n, m, s, dep[MAXN], fa[MAXN][30];

void build_edge(int u, int v) {
    edge[++id].to = v;
    edge[id].next = head[u];
    head[u] = id;
}

void dfs(int u, int f) {
    fa[u][0] = f;
    dep[u] = dep[f] + 1;
    for(int i = 0; fa[u][i]; i++)
        fa[u][i + 1] = fa[fa[u][i]][i];
    for(int i = head[u]; i; i = edge[i].next) {
        int v = edge[i].to;
        if(v == f)
            continue;
        dfs(v, u);
    }
}

int lca(int u, int v) {
    if(dep[u] < dep[v])
        swap(u, v);
    int d = dep[u] - dep[v];
    for(int i = 0; d; i++, d >>= 1)
        if(d & 1)
            u = fa[u][i];
    if(u == v)
        return u;
    for(int i = 20; i >= 0; i--)
        if(fa[u][i] != fa[v][i]) {
            u = fa[u][i];
            v = fa[v][i];
        }
    return fa[u][0];
}

int main() {
    scanf("%d %d %d", &n, &m, &s);
    for(int i = 1, x, y; i < n; i++) {
        scanf("%d %d", &x, &y);
        build_edge(x, y);
        build_edge(y, x);
    }
    dfs(s, 0);
    int a, b;
    while(m--) {
        scanf("%d %d", &a, &b);
        printf("%d\n", lca(a, b));
    }
    return 0;
}

建模 / $O(logn)$级 LCA

LCA 倒不是非常难,难的是对题目进行建模,转化为可以用 LCA 算法解决的问题。

例题 P1852

观察一下这道题,首先$[-10^9,10^9]$的范围就足够恶心;其次,还要考虑无解的情况。可是如果单纯靠 bfs,空间时间都会爆炸,而且没法判无解。

这个时候就得从问题的特殊性出发,把问题抽象化。题面说每个棋子跳动的规则是:只越过一个另外的棋子,并且和越过棋子的距离保持不变。此外,数轴上每个整点只能放下一个棋子。

继续深入:设这三个棋子用$(a,b,c)$表示,且不妨设$d_1=b-a,d_2=c-b$。对于一般的情况,$d_1\neq d_2$时,有三种跳法:

  1. $b$往左跳;
  2. $b$往右跳;
  3. $a,c$中离$b$较近的往中间跳。

并且,这些跳法都是可逆的

而当$d_1=d_2$时,则无法进行跳法3,只能进行跳法1、跳法2。

不妨把每种状态抽象成一个节点,每次跳看作从一个节点走到另一个节点。如果把跳法1、跳法2看作跳往子节点,跳法3看作跳往父节点,可以发现这个问题的背景是二叉树。如果发现$d_1=d_2$,那么这个节点就是一棵树的根节点;整个数轴直接被转化为二叉树森林。由于只存在三种跳法,因此树与树之间一定是独立的,两个处于不同树上的节点不可能互达。

这样以后,考虑 bfs 时遇到的所有困难都能够解决了。

$[-10^9,10^9]$的范围虽然巨大,但是每种情况的根一定满足$d_1=d_2$;同时,无解的等价条件即根节点不同。再看询问的最少次数,其实就是求树上两点间的距离,由于边权都是$1$,问题化为找 LCA 然后利用深度信息获得答案。

再处理剩余的细节,如何快速找到根节点呢?如果每次都靠模拟一点点跳,肯定是非常慢的;考虑如何对这个过程进行加速:设当前状态表示为$(a,b,c)$,且$d_1<d_2$。在跳的时候,由于不关注棋子的顺序,可以看作这两个棋子一起向右平移$d_1$个单位,直到$d_1\geq d_2$。所以只需要一口气跳$t=\left\lfloor \frac{d_2-1}{d_1}\right\rfloor$次就可以了。这个方法实际上相当于用辗转相除法优化更损相减法,复杂度降至$log$级别。

现在可以在$O(logk)$的复杂度内计算出节点$x$的$k$级父亲,但是当把初始节点$st$和目标节点$ed$跳到同一深度以后,是没法再像倍增一样每次跳$2^i$渐渐逼近 LCA 的,所以这里又用了二分法。复杂度为$O(log^2n)$。

总之,LCA 是不一定必须用倍增优化的,这道题里头就出现了辗转相除法、二分法等妖魔鬼怪。实际上,LCA 的加速只需要保证复杂度降到$log$级别就好,根本不是只能用倍增、树链剖分。以后的思路得发散一点了。

#include <bits/stdc++.h>
using namespace std;
struct state {
    int a[3];
    bool operator == (const state &rhs) const {
        for(int i = 0; i < 3; i++)
            if(a[i] != rhs.a[i])
                return false;
        return true;
    }
} st, ed;
int ans = 0;

state get_root(state x, int &depth) {
//  找到根节点,同时记录深度
    int d1 = x.a[1] - x.a[0], d2 = x.a[2] - x.a[1], t;
    while(d1 != d2) {
        if(d1 < d2) {
            t = (d2 - 1) / d1;
            depth += t;
            x.a[0] += t * d1;
            x.a[1] += t * d1;
        }
        else {
            t = (d1 - 1) / d2;
            depth += t;
            x.a[1] -= t * d2;
            x.a[2] -= t * d2;
        }
        d1 = x.a[1] - x.a[0], d2 = x.a[2] - x.a[1];
    }
    return x;
}

state kth_pre_state(state x, int k) {
//  找到 x 节点的k级父亲
    state res = x;
    int d1 = res.a[1] - res.a[0], d2 = res.a[2] - res.a[1], t;
    if(d1 < d2) {
        t = min(k, (d2 - 1) / d1);
        k -= t;
        res.a[0] += t * d1;
        res.a[1] += t * d1;
    }
    else {
        t = min(k, (d1 - 1) / d2);
        k -= t;
        res.a[1] -= t * d2;
        res.a[2] -= t * d2;
    }
//  熟悉的辗转相除
    if(k)
        return kth_pre_state(res, k);
    else
        return res;
}

int main() {
    int a, b, c, x, y, z;
    scanf("%d%d%d%d%d%d", &a, &b, &c, &x, &y, &z);
    st = (state){{a, b, c}}, ed = (state){{x, y, z}};
    sort(st.a, st.a + 3);
    sort(ed.a, ed.a + 3);
    int depth1 = 0, depth2 = 0;
    state root1 = get_root(st, depth1);
    state root2 = get_root(ed, depth2);
    if(!(root1 == root2)) {
        printf("NO");
        return 0;
    }
//  LCA 算法由此开始
    if(depth1 < depth2) {
        swap(depth1, depth2);
        swap(st, ed);
    }
    ans += depth1 - depth2;
    st = kth_pre_state(st, depth1 - depth2);
    int l = 0, r = depth2;
//  采用二分法枚举
    while(l < r) {
        int mid = (l + r) >> 1;
        state s1 = kth_pre_state(st, mid);
        state s2 = kth_pre_state(ed, mid);
        if(s1 == s2)
            r = mid;
        else
            l = mid + 1;
    }
    printf("YES\n%d", ans + (l << 1));
    return 0;
}

RMQ LCA

思路:每经过一条边就把所到的点记录下来,这样会生成一个长度为$2n-1$的序列。设 dfn[u] 表示第一次经过$u$时对应的 cnt,那么可以证明任意两点$u,v$的 LCA 在序列中的位置一定在 $[$dfn[u]$, $dfn[v]$]$ 中,所以 LCA 的查询就变成了一个静态区间 RMQ 问题,用 ST 表即可实现$O(nlog2n)$预处理,$O(1)$查询。

然而在模板中,这种做法比倍增慢多了

#include <bits/stdc++.h>
using namespace std;
const int MAXN = 5e5 + 10;
struct sidetable {
    int to, next;
} edge[MAXN << 1];
int head[MAXN], id = 0;
int n, m, k, s;
int dfn[MAXN], cnt = 0;
int st[MAXN << 1][20];
int dep[MAXN];

void build_edge(int u, int v) {
    edge[++id].to = v;
    edge[id].next = head[u];
    head[u] = id;
}

int Min(int u, int v) {
    return dep[u] < dep[v] ? u : v;
}

void dfs(int u, int f) {
    dep[u] = dep[f] + 1;
    st[++cnt][0] = u;
    if(!dfn[u])
        dfn[u] = cnt;
    for(int i = head[u]; i; i = edge[i].next) {
        int v = edge[i].to;
        if(v == f)
            continue;
        dfs(v, u);
        st[++cnt][0] = u;
    }
}

void prepare() {
    k = (int)(log(cnt) / log(2));
    for(int j = 1; j <= k; j++)
        for(int i = 1; i + (1 << j) - 1 <= cnt; i++)
            st[i][j] = Min(st[i][j - 1], st[i + (1 << (j - 1))][j - 1]);
}

int lca(int u, int v) {
    int l = dfn[u], r = dfn[v];
    if(l > r)
        swap(l, r);
    int p = (int)(log(r - l + 1) / log(2));
    return Min(st[l][p], st[r - (1 << p) + 1][p]);
}

int main() {
    scanf("%d %d %d", &n, &m, &s);
    for(int i = 1, u, v; i < n; i++) {
        scanf("%d %d", &u, &v);
        build_edge(u, v);
        build_edge(v, u);
    }
    dfs(s, 0);
    prepare();
    for(int a, b; m; m--) {
        scanf("%d %d", &a, &b);
        printf("%d\n", lca(a, b));
    }
    return 0;
}

特殊距离问题

例题 P4281

给出树上三点$a,b,c$,求一个点$u$使得$u$到$a,b,c$的距离之和最小。

设$a,b,c$两两的 LCA 为$l_1,l_2,l_3$,画图后可以看出需要求的点$u$其实就是深度最大的$l$,而最小距离之和其实是定值$dep[a]+dep[b]+dep[c]-dep[l_1]-dep[l_2]-dep[l_3]$,不需要分类讨论。

其他应用:判相交

例题 P3398

题目是说给出两条路径的端点,判断这两条路径是否有交集。

经过 题解 仔细的思考,发现:如果两条路径相交,必有一条路径的 LCA 在另一条路径上。而一个点$u$在一条路径$(a,b)$上的等值条件是$dis(a,u)+dis(b,u)=dis(a,b)$,所以剩下的事情就用 LCA 处理就好了。

#include <bits/stdc++.h>
using namespace std;
const int MAXN = 1e5 + 10;
struct type_edge {
    int to, next;
} edge[MAXN << 1];
int head[MAXN], id = 0;
int n, q, fa[MAXN][30], dep[MAXN];

void build_edge(int u, int v) {
    edge[++id].to = v;
    edge[id].next = head[u];
    head[u] = id;
}

void dfs(int u, int f) {
    dep[u] = dep[f] + 1;
    fa[u][0] = f;
    for(int i = 0; fa[u][i]; i++)
        fa[u][i + 1] = fa[fa[u][i]][i];
    for(int i = head[u]; i; i = edge[i].next) {
        int v = edge[i].to;
        if(v == f)
            continue;
        dfs(v, u);
    }
}

int LCA(int u, int v) {
    if(dep[u] < dep[v])
        swap(u, v);
    int d = dep[u] - dep[v];
    for(int i = 0; d; i++, d >>= 1)
        if(d & 1)
            u = fa[u][i];
    if(u == v)
        return u;
    for(int i = 20; i >= 0; i--)
        if(fa[u][i] != fa[v][i]) {
            u = fa[u][i];
            v = fa[v][i];
        }
    return fa[u][0];
}

int dis(int u, int v) {
    int lca = LCA(u, v);
    return dep[u] + dep[v] - (dep[lca] << 1);
}

bool online(int a, int b, int u) {
    return dis(a, u) + dis(b, u) == dis(a, b);
}

int main() {
    scanf("%d %d", &n, &q);
    for(int i = 1; i < n; i++) {
        int u, v;
        scanf("%d %d", &u, &v);
        build_edge(u, v);
        build_edge(v, u);
    }
    dfs(1, 0);
    while(q--) {
        int a, b, c, d;
        scanf("%d %d %d %d", &a, &b, &c, &d);
        bool flag = false;
        flag = flag | online(a, b, LCA(c, d));
        flag = flag | online(c, d, LCA(a, b));
        if(flag)
            printf("Y\n");
        else
            printf("N\n");
    }
}