To boldly go where no one has gone before.

【学习笔记】树的直径

2019-03-06 23:21:10


这里记录一种由@杨林树♂ 教我的求树的直径的一套方法,解决的问题是P3304

核心思路

直径的定义:一棵树上最长的路径叫做树的直径。可以证明:树上的直径一定是从与树根距离最大的点,到树根,再到与树根距离次大的点。

先用$vector$存边 (链式前向星莫名MLE)。

#define ll long long

struct EDGE {
    int from, to;
    ll dis;
}; //do it with vector!
vector<EDGE> edge[MAXN];

inline void build_edge(int u, int v, ll d) {
    edge[u].push_back((EDGE){u, v, d});
    edge[v].push_back((EDGE){v, u, d});
    return;
}

对于求直径长度的问题,只需用两次$dfs$即可解决。第一次$dfs$找出直径的一个端点$up$,然后对记录的距离信息$memset$。此后,再以找出的这个端点为树根进行第二次$dfs$找到$down$,就把两个端点都找出来了,它们间的距离即直径长度$D$。

inline void init() {
    dfs1(1, 0);
    for(register int i=1; i<=n; i++)
        up = dis[i]>dis[up] ? i : up;
    memset(dis, 0, sizeof(dis));
    dfs1(up, 0);
    for(register int i=1; i<=n; i++)
        down = dis[i]>dis[down] ? i : down;
    D = dis[down];
    return;
}

特殊处理

之后就是解决P3304特色问题了。由于树的直径不一定唯一,所以这道题询问所有直径都经过的边数。

基佬教我的方法是进行第二次$dfs$,具体为查找从当前点$u$出发是否还存在长度为$d$的边。可以证明,设当前点$u$离树的一个端点距离为$d$,假如存在一条不存在于直径中的路径,其长度等于$d$,那么$u$到该端点之间的所有点都不是被所有直径经过的点。 有了这样一条性质,我们只需要先统计出目前发现的直径上的点数$count$,初始化$l=0,r=count$。枚举目前发现的直径上的点$u$,判断:

  1. 当前点$u$到$up$间是否还存在另一条与$u$到$up$距离相等的点,如果存在把$l$更新为自己与$cnt$的较大值,表示从树上第$cnt$个点开始才符合条件;

  2. 当前点$u$到$down$间是否还存在另一条与$u$到$down$距离相等的点,如果存在把$r$更新为自己与$cnt$的较小值,表示到树上第$cnt$个点为止才符合条件。最后直接输出$r-l$就行了。

bool dfs2(int u, int d) {
    if(!d)
        return true;
    //iterator遍历vector:
    for(vector<EDGE>::iterator it = edge[u].begin(); it != edge[u].end(); it++) {
        EDGE o = *it;
        if(o.to == fa[u])
            continue;
        if(dfs2(o.to, d-o.dis))
            return true;
    }
    return false;
}

// 主函数核心代码:
int main() {
    //do something
    int cnt = 0;
    for(int i=down; i!=up; i=fa[i]) 
        son[fa[i]] = i, cnt++;
    int l = 0, r = cnt;
    cnt = 0;
    for(register int i=up; i!=down; i=son[i]) {
        for(vector<EDGE>::iterator it = edge[i].begin(); it!=edge[i].end(); it++) {
            EDGE o = *it;
            if(o.to == son[i] || o.to == fa[i])
                continue;
            if(dfs2(o.to, dis[i]-o.dis))
                l = max(l, cnt);
            if(dfs2(o.to, dis[down]-dis[i]-o.dis))
                r = min(r, cnt);
        }
        cnt++;
    }
}

完整代码

#include<bits/stdc++.h>
#define ll long long
using namespace std;

const int MAXN = 200010;

struct EDGE {
    int from, to;
    ll dis;
}; //do it with vector!
vector<EDGE> edge[MAXN];
int fa[MAXN], son[MAXN], n, a, b, l;
ll dis[MAXN];

int read() {
    int res = 0, uz = 1;
    char ch = getchar();
    while(ch < '0' || ch > '9') {
        if(ch == '-')
            uz = -1;
        ch = getchar();
    }
    while(ch >= '0' && ch <= '9') {
        res = (res<<3) + (res<<1) + (ch^48);
        ch = getchar();
    }
    return res * uz;
}

inline void build_edge(int u, int v, ll d) {
    edge[u].push_back((EDGE){u, v, d});
    edge[v].push_back((EDGE){v, u, d});
    return;
}

inline void dfs1(int u, int father) {
    fa[u] = father;
    for(vector<EDGE>::iterator it = edge[u].begin(); it != edge[u].end(); it++) {
        EDGE o = *it;
        if(o.to == father)
            continue;
        dis[o.to] = dis[u] + o.dis;
        dfs1(o.to, u);
    }
    return;
}

int up, down;
ll D;

inline void init() {
    dfs1(1, 0);
    for(register int i=1; i<=n; i++)
        up = dis[i]>dis[up] ? i : up;
    memset(dis, 0, sizeof(dis));
    dfs1(up, 0);
    for(register int i=1; i<=n; i++)
        down = dis[i]>dis[down] ? i : down;
    D = dis[down];
    return;
}

bool dfs2(int u, int d) {
    if(!d)
        return true;
    for(vector<EDGE>::iterator it = edge[u].begin(); it != edge[u].end(); it++) {
        EDGE o = *it;
        if(o.to == fa[u])
            continue;
        if(dfs2(o.to, d-o.dis))
            return true;
    }
    return false;
}

int main() {
    n = read();
    for(register int i=1; i<n; i++) {
        a = read(); b = read(); l = read();
        build_edge(a, b, (ll)l);
    }
    init();
    int cnt = 0;
    for(int i=down; i!=up; i=fa[i]) 
        son[fa[i]] = i, cnt++;
    int l = 0, r = cnt;
    cnt = 0;
    for(register int i=up; i!=down; i=son[i]) {
        for(vector<EDGE>::iterator it = edge[i].begin(); it!=edge[i].end(); it++) {
            EDGE o = *it;
            if(o.to == son[i] || o.to == fa[i])
                continue;
            if(dfs2(o.to, dis[i]-o.dis))
                l = max(l, cnt);
            if(dfs2(o.to, dis[down]-dis[i]-o.dis))
                r = min(r, cnt);
        }
        cnt++;
    }
    printf("%lli\n%d", D, r-l);
    return 0;
}
//code

written at 2019/03/06/13:52