[Centroid Decomposition] 2019 ICPC Seoul: Gene Tree

Problem

Link: //ac.nowcoder.com/acm/contest/15644/B
Source: Nowcoder

A gene tree is a tree showing the evolution of various genes or biological species. A gene tree represents the relatedness of specific genes stored at the leaf nodes without assumption about their ancestry. Leaf nodes represent genes, called taxa, and internal nodes represent putative ancestral taxa. Each edge in the tree is associated with a positive integer, phylogenetic length, which quantifies the evolutionary distance between two nodes of the edge. For example, the left figure below shows a gene tree with six leaf nodes, which approximates the relation among six taxa, and the right one shows a gene tree with four taxa.
Like the trees $T_1$ and $T_2$ above, gene trees are modeled as unrooted trees where each internal node (non-leaf node) has degree three. A path-length between two leaf nodes is the sum of the phylogenetic lengths of the edges along the unique path between them. In $T_1$, the path-length between Human and Cow is 2 + 3 = 5 and the path-length between Human and Goldfish is 2 + 4 + 8 + 10 = 24. These lengths indicate that Human is much closer to Cow than to Goldfish genetically. From $T_2$, we can guess that the primate closest to Human is Chimpanzee.

Researchers are interested in measuring the distance between genes in the tree. A famous distance measure is the sum of squared path-lengths of all unordered leaf pairs. More precisely, such a distance $d(T)$ is defined as follows:
$$d(T) = \sum_{\text{unordered pair } (u,v)} p_{u,v}^2$$
where $p_{u,v}$ is a path-length between two leaf nodes u and v in $T$. Note that $d(T)$ is the sum of the squared path-lengths $p_{u,v}^2$ over all unordered leaf pairs u and v in $T$. For the gene tree $T_2$ in Figure B.1, there are six paths over all unordered leaf pairs, (Human, Chimpanzee), (Human, Gorilla), (Human, Orangutan), (Chimpanzee, Gorilla), (Chimpanzee, Orangutan), and (Gorilla, Orangutan). The sum of squared path-lengths is $2^2 + 4^2 + 5^2 + 4^2 + 5^2 + 5^2 = 111$, so $d(T_2) = 111$.

Given an unrooted gene tree $T$, write a program to output $d(T)$.


Input description:

Your program is to read from standard input. The input starts with a line containing an integer n (4 ≤ n ≤ 100,000), where n is the number of nodes of the input gene tree T. Then T has n − 1 edges. The nodes of T are numbered from 1 to n. The following n − 1 lines represent n − 1 edges of T, where each line contains three non-negative integers a, b, and l (1 ≤ a ≠ b ≤ n, 1 ≤ l ≤ 50) where two nodes a and b form an edge with phylogenetic length l.

Output description:

Your program is to write to standard output. Print exactly one line. The line should contain one positive integer d(T).



Summary

Given an unrooted tree, compute the sum of the squared path lengths over all unordered pairs of leaves.


Solution

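The intended solution seems to be a rerooting DP. But I had spent half an hour on centroid decomposition the day before the contest, so I was convinced this was a centroid decomposition problem. I did not know the technique well at the time: after fixing my bugs, the submission got TLE, and only while upsolving did I learn that centroid decomposition must find a fresh centroid in every subtree to reach O(n log n).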

I am not a DP person and do not know how the rerooting DP goes, so I will just explain centroid decomposition.

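Centroid decomposition is a divide-and-conquer technique on trees, and it handles path problems on trees well. It views a tree as a root plus its subtrees, and each subtree in turn splits into a root plus its subtrees; that is the unit of divide and conquer. To keep the recursion shallow, the root chosen for each component is its centroid.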

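Under this split, every path in the tree falls into one of two cases:

1. The path passes through the root.

2. The path does not pass through the root.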

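We handle these two cases and keep dividing recursively, which eventually accounts for every path.

The second case is handled by the recursion, so at each step we only need to handle the first.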

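For a path that passes through the root, the length between two leaves is the sum of their distances to the root, so that is all we need. Let dis[x] be the distance from node x to the root; one DFS computes every dis value, after which the length of any such leaf-to-leaf path is available in O(1). That alone is not enough, though: matching the leaves pair by pair costs O(n^2), which the limits cannot tolerate. Fortunately, a little combinatorics lets us count the answer in groups. For example, with 3 leaves a1, a2, a3, the pairs are a1-a2, a1-a3 and a2-a3, so the sum of squared path lengths is (dis[a1]+dis[a2])^2+(dis[a1]+dis[a3])^2+(dis[a2]+dis[a3])^2.

Writing dis[ai]=di and expanding gives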

2*d1^2+2*d1*(d2+d3)+d2^2+d3^2 +(d2+d3)^2
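As a quick numeric check of this expansion, with the arbitrary values d1 = 1, d2 = 2, d3 = 3 both sides agree:

$$
(1+2)^2 + (1+3)^2 + (2+3)^2 = 9 + 16 + 25 = 50,
\qquad
2\cdot 1^2 + 2\cdot 1\cdot(2+3) + 2^2 + 3^2 + (2+3)^2 = 2 + 10 + 4 + 9 + 25 = 50.
$$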

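Setting the a2-a3 path aside for now, look only at the paths from a1 to the other leaves.

Let n be the number of leaves, sum(i,j) = di + ... + dj, and ssum(i,j) = di^2 + ... + dj^2. The paths from a1 then contribute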

(n-1)*d1^2+2*d1*sum(2,n)+ssum(2,n)

The remaining paths, such as a2-a3, are simply the paths among the leaves left after removing a1, starting again from a2.

So the formula generalizes to

(n-1)*d1^2+2*d1*sum(2,n)+ssum(2,n)+(n-2)*d2^2+2*d2*sum(3,n)+ssum(3,n)+…

Both sum and ssum can be read off prefix sums, so the sum of squared lengths over all leaf pairs takes O(n) time.

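Everything so far assumes each subtree of the root holds a single leaf. If a subtree has several leaves, the formula above also counts pairs of leaves from the same subtree, but those paths do not pass through the root, so it needs adjusting. The simplest fix is to count leaf by leaf and skip leaves in the same subtree. This is easy to implement: one DFS computes bt[x], the child of the root whose subtree contains x, and bt[x] partitions the leaves; with some precomputation this is also O(n).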

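There is a better approach, though.

All leaves in the same subtree are paired with exactly the same set of other leaves, and we can exploit that.

Let a1, a2, a3 be the leaves of one subtree, m the number of leaves outside that subtree, sum the sum of their distances, and ssum the sum of their squared distances. The paths from a1, a2, a3 contribute

m*d1^2+2*d1*sum+ssum+m*d2^2+2*d2*sum+ssum+m*d3^2+2*d3*sum+ssum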

which rearranges to

m*(d1^2+d2^2+d3^2)+2*(d1+d2+d3)*sum+3*ssum

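In general, let n be the number of leaves in the subtree, sum1 and ssum1 the sum and squared sum of their distances, and m, sum2 and ssum2 the count, sum and squared sum for the leaves outside it. The contribution is

m*ssum1+2*sum1*sum2+n*ssum2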

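This handles the leaves one block at a time, and prefix sums maintain the block totals quickly and conveniently. To count every pair exactly once, pair each block only with the blocks that come after it.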


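Because centroid decomposition picks a fresh centroid in every recursive call, each piece left after removing it has at most half the nodes of the current component. The recursion is therefore about log2 n levels deep, each level does O(n) work in total, and the overall complexity is O(n log n).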
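Written out, with C a component at recursion level t, and assuming the per-component work is linear in |C| (which holds here: counting, the centroid search, the distance DFS and the block loop are each linear, and every node has degree at most 3), and components on the same level are disjoint:

$$
|C| \le \frac{n}{2^{t}} \;\Rightarrow\; t \le \log_2 n,
\qquad
\sum_{t=0}^{\lfloor \log_2 n \rfloor}\ \sum_{C \text{ at level } t} O(|C|) \;\le\; \sum_{t=0}^{\lfloor \log_2 n \rfloor} O(n) \;=\; O(n \log n).
$$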

Code

#include<cstdio>
#include<vector>
#include<algorithm>
using namespace std;
typedef long long ll;
const ll N = 1e5 + 5;

struct node {
    ll to, w;
    node() {}
    node(ll a, ll b) :to(a), w(b) {}
};
vector<node> e[N];

ll Gsize[N];
ll n;
ll Gans, root;
ll vis[N];

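// Compress chains of degree-2 nodes into single weighted edges. The statement
// guarantees every internal node has degree three, so valid input has no such chains.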
ll from, Pid;
ll CDSPsum;
ll CDSPcnt;
ll CDSPnum;
void cdsp(ll x, ll f) {

    if (e[x].size() == 2) {
        CDSPcnt++;
        vis[x] = 1;
        if (e[x][0].to != f) {
            CDSPsum += e[x][0].w;
            cdsp(e[x][0].to, x);
        }
        else {
            CDSPsum += e[x][1].w;
            cdsp(e[x][1].to, x);
        }
    }
    else {
        ll a = from, b = Pid, d = CDSPcnt;
        ll c = CDSPsum;
        for (ll i = 0; i < e[x].size(); i++) {
            ll y = e[x][i].to;
            if (y == f) {
                if (CDSPsum&&from&&from != f) {
                    e[a][b].to = x;
                    e[a][b].w = c;
                    e[x][i].to = a;
                    e[x][i].w = c;
                    CDSPnum -= d;
                }
                continue;
            }
            from = x;
            Pid = i;
            CDSPsum = e[x][i].w;
            CDSPcnt = 0;
            cdsp(y, x);
        }
    }


}


// Count the nodes of the current component (stored in tnum).
ll tnum;
void getnum(ll x) {
    tnum++;

    for (int i = 0; i < e[x].size(); i++) {
        ll y = e[x][i].to;
        if (vis[y])continue;
        vis[y] = 1;
        getnum(y);
        vis[y] = 0;
    }

}


// Find the centroid: the node whose largest remaining piece (mp) is smallest.

void Gdfs(ll x) {

    Gsize[x] = 1;
    ll mp = 0;
    for (ll i = 0; i < e[x].size(); i++) {
        ll y = e[x][i].to;
        if (vis[y])continue;
        vis[y] = 1;
        Gdfs(y);
        Gsize[x] += Gsize[y];
        if (mp < Gsize[y])
            mp = Gsize[y];
        vis[y] = 0;

    }
    mp = max(mp, tnum - Gsize[x]);
    if (mp < Gans) {
        Gans = mp;
        root = x;
    }

}


ll dis[N];
ll bt[N];
ll leaf[N];
ll llen;
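// From the centroid (root): compute dis[] for the component, collect the
// original leaves (degree 1, excluding the centroid itself) in DFS order,
// and propagate bt[] = the root's child whose subtree contains the node.
void dfs(ll x) {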

    if (e[x].size() == 1 && x != root) {
        leaf[++llen] = x;
    }
    for (ll i = 0; i < e[x].size(); i++) {
        ll y = e[x][i].to;
        if (vis[y])continue;
        if (x != root)bt[y] = bt[x];
        vis[y] = 1;
        dis[y] = dis[x] + e[x][i].w;
        dfs(y);
        vis[y] = 0;
    }
}

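// Running answer. For adversarial inputs (e.g. a long caterpillar whose edges
// all have length 50), d(T) can exceed the long long range, so a 128-bit
// accumulator would be safer.
ll ans;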
ll sf[N], ssf[N];

// Centroid decomposition: solve the component containing x.
ll L[N], R[N];
ll slen;
void calc(ll x) {

    Gans = 1e9;
    tnum = 0;
    vis[x] = 1;
    getnum(x);
    Gdfs(x);
    vis[x] = 0;
    x = root;
    bt[x] = x;
    for (ll i = 0; i < e[x].size(); i++) {
        bt[e[x][i].to] = e[x][i].to;
    }
    llen = 0;
    dis[x] = 0;
    vis[x] = 1;
    dfs(x);
    vis[x] = 0;
    sf[0] = ssf[0] = 0;
    for (ll i = 1; i <= llen; i++) {
        sf[i] = sf[i - 1] + dis[leaf[i]];
        ssf[i] = ssf[i - 1] + dis[leaf[i]] * dis[leaf[i]];
    }


    // Leaves were collected in DFS order, so the leaves under one child of the
    // centroid form a contiguous run; store run i as positions [L[i], R[i]].
    slen = 0;
    for (ll l = 1, r = 1; r <= llen; r++) {
        if (r == llen || bt[leaf[r]] != bt[leaf[r + 1]]) {
            L[slen] = l;
            R[slen++] = r;
            l = r + 1;
        }
    }
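    // Pair run i with all later runs (positions L[i+1]..R[slen-1]) so each
    // cross-subtree pair is counted once: cntB*ssumA + cntA*ssumB + 2*sumA*sumB.
    for (ll i = 0; i < slen - 1; i++) {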
        ll suma = sf[R[i]] - sf[L[i] - 1];
        ll ssuma = ssf[R[i]] - ssf[L[i] - 1];
        ll sumb = sf[R[slen - 1]] - sf[L[i + 1] - 1];
        ll ssumb = ssf[R[slen - 1]] - ssf[L[i + 1] - 1];
        ans += (R[slen - 1] - L[i + 1] + 1)*ssuma + (R[i] - L[i] + 1)*ssumb + 2 * suma*sumb;
    }


    // Remove the centroid, then recurse into each remaining piece.
    vis[root] = 1;
    for (ll i = 0; i < e[x].size(); i++) {
        ll y = e[x][i].to;
        if (vis[y])continue;
        calc(y);
    }
}



inline ll read() {
    ll s = 0, w = 1;
    char ch = getchar();
    while (ch<'0' || ch>'9') { if (ch == '-')w = -1; ch = getchar(); }
    while (ch >= '0'&&ch <= '9') s = s * 10 + ch - '0', ch = getchar();
    return s * w;
}


int main() {

    scanf("%lld", &n);
    ll a, b, v;
    ll lf;
    for (ll i = 0; i < n - 1; i++) {
        a = read();
        b = read();
        v = read();
        e[a].emplace_back(node(b, v));
        e[b].emplace_back(node(a, v));
    }

    for (ll i = 1; i <= n; i++) {
        if (e[i].size() == 1) {
            lf = i;
            break;
        }
    }
    // This chain compression barely affects speed: it was only about 3 ms faster.

    CDSPnum = n;
    cdsp(lf, 0);
    if (CDSPnum == 2) {
        ans = e[lf][0].w*e[lf][0].w;
    }

    calc(lf);

    printf("%lld\n", ans);

    return 0;
}