树上分治

树上分治

点分治 \(O(nlogn)\)

主要解决有关树上路径统计的问题(其中路径的边权可能需要满足一些条件)
1.基本思想:点分治的本质其实是将一棵树拆分成许多棵子树处理,并不断进行。
2.分治点的选择:树的重心
3.点分治

  • 路径的两个端点在同一个子树内
  • 路径的两个端点不在同一个子树内
  • 路径的某个端点是重心

基本模板

#include <bits/stdc++.h>
#define fi first
#define se second
#define ll long long
using namespace std;
const int N=1e4+5;
vector<pair<int,int>>E[N];
int n,k,S; //S记录根据重心划分之后当前遍历的子树的大小
int sz[N],mxson[N];
bool vis[N];//是否是重心
int MX,root,dist[N],cnt;
ll ans;
void getroot(int u,int fa)
{
    sz[u]=1,mxson[u]=0;
    for(auto [v,w]:E[u])
    {
        if(v==fa||vis[v]) continue;
        getroot(v,u);
        sz[u]=sz[u]+sz[v];
        mxson[u]=max(mxson[u],sz[v]);
    }
    mxson[u]=max(mxson[u],S-sz[u]);
    if(mxson[u]<MX) root=u,MX=mxson[u];
}
void getdist(int u,int fa)
{
    
}
int solve(int u)
{
}
//分治难点在于统计合并答案
void Divide(int u)
{
    solve(u,1); //统计答案
    vis[u]=1;
    for(auto [v,w]:E[u])
    {
        if(vis[v]) continue;
        solve(v,-1); //统计答案(这里可能会用到容斥之类的)
        S=sz[v],root=0,MX=N;
        getroot(v,0);
        Divide(root);
    }
    return;
}
int main()
{
    ios::sync_with_stdio(false);
    cin.tie(0);
    while(cin>>n>>k,n&&k)
    {
        for(int i=1;i<=n;i++) vis[i]=0,E[i].clear();
        for(int i=1;i<n;i++)
        {
            int u,v,w;
            cin>>u>>v>>w;
            u++,v++;
            E[u].push_back({v,w});
            E[v].push_back({u,w});
        }
        ans=0;
        S=n,MX=N;
        getroot(1,0);
        Divide(root);
        cout<<ans<<'\n';
    }
    
}

基础例题

1.

给定一个有 \(N\)个点(编号 \(0,1,…,N−1\))的树,每条边都有一个权值(不超过 1000)。
树上两个节点 \(x\)\(y\) 之间的路径长度就是路径上各条边的权值之和。
求长度不超过 \(K\) 的路径有多少条。

思路:分治后在每个子树中暴力求出所有点到分治点的距离,然后将所有距离存在一个数组中排序,利用双指针的方法统计答案,然后对于记重的边用容斥原理去解决。

#include <bits/stdc++.h>
#define fi first
#define se second
#define ll long long
using namespace std;
const int N=1e4+5;
vector<pair<int,int>>E[N];
int n,k,S; //S记录根据重心划分之后当前遍历的子树的大小
int sz[N],mxson[N];
bool vis[N];//是否是重心
int MX,root,dist[N],cnt;
ll ans;
void getroot(int u,int fa)
{
    sz[u]=1,mxson[u]=0;
    for(auto [v,w]:E[u])
    {
        if(v==fa||vis[v]) continue;
        getroot(v,u);
        sz[u]=sz[u]+sz[v];
        mxson[u]=max(mxson[u],sz[v]);
    }
    mxson[u]=max(mxson[u],S-sz[u]);
    if(mxson[u]<MX) root=u,MX=mxson[u];
}
void getdist(int u,int fa,int d)
{
    dist[++cnt]=d;
    for(auto [v,w]:E[u])
    {
        if(v==fa||vis[v]) continue;
        getdist(v,u,w+d);
    }
    return;
    
}
int solve(int u,int d)
{
    cnt=0;
    memset(dist,0,sizeof dist);
    getdist(u,0,d);
    sort(dist+1,dist+1+cnt);
    int l=1,r=cnt,res=0;
    while(l<=r)  //排序后双指针统计答案
    {
        if(dist[r]+dist[l]<=k) res+=r-l,l++;
        else r--;
    }
    
    return res;
}
void Divide(int u)
{
    ans=ans+solve(u,0);
    vis[u]=1;
    for(auto [v,w]:E[u])
    {
        if(vis[v]) continue;
        ans=ans-solve(v,w); //容斥原理
        S=sz[v],root=0,MX=N;
        getroot(v,0);
        Divide(root);
    }
    return;
}
int main()
{
    ios::sync_with_stdio(false);
    cin.tie(0);
    while(cin>>n>>k,n&&k)
    {
        for(int i=1;i<=n;i++) vis[i]=0,E[i].clear();
        for(int i=1;i<n;i++)
        {
            int u,v,w;
            cin>>u>>v>>w;
            u++,v++;
            E[u].push_back({v,w});
            E[v].push_back({u,w});
        }
        ans=0;
        S=n,MX=N;
        getroot(1,0);
        Divide(root);
        cout<<ans<<'\n';
    }
    
}

2.

给定一棵 \(N\)个节点的树,每条边带有一个权值。
求一条简单路径,路径上各条边的权值和等于 \(K\),且路径包含的边的数量最少。

#include <bits/stdc++.h>
#define ll long long
using namespace std;
constexpr int N=2e5+10,M=1e6+10,inf=0x7f7f7f7f;
vector<pair<int,int>>E[N];
int n,k;
int root,S,MX;
int sz[N],mxson[N];
bool vis[N];
int dist[N],num[M];

void getroot(int u,int fa)
{
    sz[u]=1,mxson[u]=0;
    for(auto [v,w]:E[u])
    {
        if(v==fa||vis[v]) continue;
        getroot(v,u);
        sz[u]=sz[u]+sz[v];
        mxson[u]=max(mxson[u],sz[v]);
    }
    mxson[u]=max(mxson[u],S-sz[u]);
    if(mxson[u]<MX) root=u,MX=mxson[u];
}
vector<pair<int,int>>t;
void getdist(int u,int fa,int K)//求出这个子树中其他的点到这个子树的重心的距离
{

    if(dist[u]<=k) t.emplace_back(dist[u],K);
    for(auto &[v,w]:E[u])
    {
        if(vis[v]||v==fa) continue;
        dist[v]=dist[u]+w;
        getdist(v,u,K+1);
    }
}
int res;
int q[N],tt=-1;
int solve(int u)
{
    num[0]=0;
    for(auto [v,w]:E[u])
    {
        
        if(vis[v]) continue;
        t.clear();
        dist[v]=w;
        getdist(v,u,1);//统计分离重心后的某一棵子树
        for(auto &[d,cnt]:t)
            res=min(res,cnt+num[k-d]);
        
        for(auto &[d,cnt]:t)
        {
            num[d]=min(num[d],cnt);
            q[++tt]=d;
        }
    }
    while(tt>=0) num[q[tt--]]=inf;

}
void Divide(int u)
{
    solve(u);
    vis[u]=1;
    for(auto [v,w]:E[u])
    {
        if(vis[v]) continue;
        S=sz[v],root=0,MX=N;
        getroot(v,0);
        Divide(root);
    }
    return;
}
int main()
{
    ios::sync_with_stdio(false);
    cin.tie(nullptr);
    memset(num,0x7f,sizeof num);
    cin>>n>>k;
    for(int i=1;i<n;i++)
    {
        int u,v,w;
        cin>>u>>v>>w;
        ++u,++v;
        E[u].emplace_back(v,w);
        E[v].emplace_back(u,w);
    }
    S=n,MX=N,root=0,res=inf;
    getroot(1,0);
    Divide(root);
    if(res==inf) cout<<-1<<'\n';
    else cout<<res<<'\n';
    return 0;
}

进阶题:

1.Prime Distance On Tree

题意:给定一个\(n\)个节点的数,问树上任选两点间且这两点间的距离为素数的概率是多少

Sol:点分治,每次递归处理时,假设当前的重心是\(root\),求出分裂后的子树的所有点到\(root\)的距离,并统计它们的个数。
\(cnt[d]\)为距离为\(d\)的路径的个数。同时还要求出所有距离相加的个数,因为这些距离是可以组合的,这一步可以利用\(NTT\)快速求出。
多项式形式为\(\sum cnt[d]x^d\),求完后暴力统计所有长度为素数的边的个数。算完后统计答案由于\(u\)\(v\)\(v\)\(u\)会统计两遍,所以答案还要除以2。分治到子树时,由于答案会可能会算重,因为在统计\(root\)的子树时可能会算到在同一个子树内的两点,所以在统计子树时要先利用容斥原理减去算重的,然后再分治算子树。

代码还有问题,不过大概思路就是这样,先插个眼

#include <bits/stdc++.h>
#define ll long long
#define pb push_back
using namespace std;

const int N = 5e6+10;
const int p = 998244353, gg = 3, ig = 332738118, img = 86583718;//1004535809 
const int mod=998244353;
template <typename T> void rd (T &x)
{
    x=0;int f=1;char s=getchar();
    while(s<'0'||s>'9'){if(s=='-') f=-1;s=getchar();}
    while(s>='0'&&s<='9') x=x*10+(s^48),s=getchar();
    x*=f;
}

ll qpow(ll a, int b)
{
    ll res = 1;
    while(b) {
        if(b & 1) res = 1ll * res * a % mod;
        a = 1ll * a * a % mod;
        b >>= 1;
    }
    return res;
}

vector<ll>A(N);
namespace Poly
{
    #define mul(x, y) (1ll * x * y >= mod ? 1ll * x * y % mod : 1ll * x * y)
    #define minus(x, y) (1ll * x - y < 0 ? 1ll * x - y + mod : 1ll * x - y)
    #define plus(x, y) (1ll * x + y >= mod ? 1ll * x + y - mod : 1ll * x + y)
    #define ck(x) (x >= mod ? x - mod : x)//取模运算太慢了

    typedef vector<ll> poly;
    const int G = 3;//根据具体的模数而定,原根可不一定不一样!!!
    //一般模数的原根为 2 3 5 7 10 6
    const int inv_G = qpow(G, mod - 2);
    int RR[N], deer[2][21][N], inv[N];

    void init(const int t) {//预处理出来NTT里需要的w和wn,砍掉了一个log的时间
        for(int p = 1; p <= t; ++ p) {
            int buf1 = qpow(G, (mod - 1) / (1 << p));
            int buf0 = qpow(inv_G, (mod - 1) / (1 << p));
            deer[0][p][0] = deer[1][p][0] = 1;
            for(int i = 1; i < (1 << p); ++ i) {
                deer[0][p][i] = 1ll * deer[0][p][i - 1] * buf0 % mod;//逆
                deer[1][p][i] = 1ll * deer[1][p][i - 1] * buf1 % mod;
            }
        }
        inv[1] = 1;
        for(int i = 2; i <= (1 << t); ++ i)
            inv[i] = 1ll * inv[mod % i] * (mod - mod / i) % mod;
    }

    int NTT_init(int n) {//快速数论变换预处理
        int limit = 1, L = 0;
        while(limit < n) limit <<= 1, L ++ ;
        for(int i = 0; i < limit; ++ i)
            RR[i] = (RR[i >> 1] >> 1) | ((i & 1) << (L - 1));
        return limit;
    }

    void NTT(poly &A, int type, int limit) {//快速数论变换
        A.resize(limit);
        for(int i = 0; i < limit; ++ i)
            if(i < RR[i])
                swap(A[i], A[RR[i]]);
        for(int mid = 2, j = 1; mid <= limit; mid <<= 1, ++ j) {
            int len = mid >> 1;
            for(int pos = 0; pos < limit; pos += mid) {
                int *wn = deer[type][j];
                for(int i = pos; i < pos + len; ++ i, ++ wn) {
                    int tmp = 1ll * (*wn) * A[i + len] % mod;
                    A[i + len] = ck(A[i] - tmp + mod);
                    A[i] = ck(A[i] + tmp);
                }
            }
        }
        if(type == 0) {
            for(int i = 0; i < limit; ++ i)
                A[i] = 1ll * A[i] * inv[limit] % mod;
        }
    }

    inline poly poly_mul(poly A, poly B) {//多项式乘法
        int deg = A.size() + B.size() - 1;
        int limit = NTT_init(deg);
        poly C(limit);
        NTT(A, 1, limit);
        NTT(B, 1, limit);
        for(int i = 0; i < limit; ++ i)
            C[i] = 1ll * A[i] * B[i] % mod;
        NTT(C, 0, limit);
        C.resize(deg);
        return C;
    }
    //多个多项式相乘CDQ或者利用优先队列启发式合并
    inline poly CDQ(int l,int r)
    {
      if(l==r)
      { 
        return poly{1,A[l]};
      }
      int mid=l+r>>1;
      poly L=CDQ(l,mid);
      poly R=CDQ(mid+1,r);
      return poly_mul(L,R);  
    }
}

using namespace Poly;
constexpr int MAXN=5e4+10;
vector<int>E[MAXN];
int mxson[MAXN],sz[MAXN];
int n,S,MX,root;
bool vis[MAXN],st[MAXN];
int primes[MAXN],primes_cnt;
void get_primes()
{
    for(int i=2;i<=50000;i++)
    {
        if(!st[i]) primes[++primes_cnt]=i;
        for(int j=1;primes[j]*i<=50000;i++)
        {
            st[primes[j]*i]=1;
            if(i%primes[j]==0) break;
        }
    }
}
void getroot(int u,int fa)
{
     sz[u]=1,mxson[u]=0;
     for(auto &v:E[u])
     {
        if(v==fa||vis[v]) continue;
        getroot(v,u);
        sz[u]=sz[u]+sz[v];
        mxson[u]=max(mxson[u],sz[v]);
     }
     mxson[u]=max(mxson[u],S-sz[u]);
     if(mxson[u]<MX) root=u,MX=mxson[u];
}
int cnt[MAXN],Len;
ll res=0;
void getdist(int u,int fa,int d)
{
    ++cnt[d];
    Len=max(Len,d);
    for(auto &v:E[u])
    {
        if(vis[v]||v==fa) continue;
        getdist(v,u,d+1);
    }
}
void solve(int u,int d,int type)
{
    Len=0;
    ll sum=0;
    getdist(u,0,d);
    poly F(Len+1);
    for(int i=0;i<=Len;i++) F[i]=cnt[i];
    sum-=cnt[1];
    poly G=poly_mul(F,F);
    int limit=G.size();
    for(int i=1;i<=primes_cnt;i++)
        if(primes[i]>limit) break;
        else sum+=G[primes[i]];
    res+=sum/2*type;
    for(int i=0;i<=Len;i++) cnt[i]=0;
    
}
void Divide(int u)
{
    solve(u,0,1);
    vis[u]=1;
    for(auto &v:E[u])
    {
        if(vis[v]) continue;
        solve(v,1,-1);
        S=sz[v],root=0,MX=MAXN;
        getroot(v,0);
        Divide(v);
    }
}
int main()
{
    //freopen("in.in","r",stdin);
    //ios::sync_with_stdio(false);
    //cin.tie(nullptr);
    get_primes();
    init(20);//2^21 = 2,097,152,根据题目数据多项式项数的大小自由调整,注意大小需要跟deer数组同步(21+1=22)
    rd(n);
    for(int i=1;i<n;i++)
    {
        int  u,v;
        rd(u),rd(v);
        E[u].emplace_back(v);
        E[v].emplace_back(u);
    }
    S=n,MX=MAXN,root=0;
    getroot(1,0);
    Divide(root);
    printf("%.1lf\n",res*2.0/(1.0*n*(n-1)));
    
    
    return 0;
}

2.Constructing Ranches$ (2019-ICPC-HongKong)

题意:给定一棵\(n\)个点的树,每个点有一个点权,问有多少条路径满足路径上的点权可以构成一个简单的凸多边形。

Sol:首先这道题显然点分治,然后利用以下结论,维护路径上的和还有最大值。
\(Lemma\)\(n\) 条边,长度为 \(a_1,a_2…a_n(a_i≤a_{i+1})\) ,其能构成一个 面积大于0的凸多边形 ,当且仅当 \(n>2\)\(\sum_{i=1}^{n-1}a_i\gt 2\times a_n\).
在当前子树中,找出重心后求出其他点到重心的距离、最大值后,在合并跨过重心的答案时,考虑利用树状数组,因为如果两两枚举的话容易T,利用树状数组,将距离离散化后当做树状数组的下标,将最大值排序,依次枚举最大值,(也是在枚举边)在树状数组中找到第一个大于等于\(2*mx-d\)的值,然后查询\(query(m)-query(id[d-1])\),其中\(m\)为离散化后的最大下标,这样为什么可行是因为\(id[d-1]\)后的距离一定大于等于\(d\),那么\(id[d-1]\)后面有几个数,说明就有多少条边可以和当前枚举的边配对,因为\(2*mx-d<x\)等价于\(2*mx<d+x\)

还有错误,待查

#include <bits/stdc++.h>
#define ll long long
//using namespace std;
constexpr int N=2e5+10;
int n,S,MX,root;
std::vector<int>E[N];
int sz[N],mxson[N];
ll res;
int w[N];
bool vis[N];
void getroot(int u,int fa)
{
    sz[u]=1,mxson[u]=0;
    for(auto &v:E[u])
    {
       if(vis[v]||v==fa) continue;
       sz[u]+=sz[v];
       mxson[u]=std::max(mxson[u],sz[v]);
    }
    mxson[u]=std::max(mxson[u],S-sz[u]);
    if(mxson[u]<MX) root=u,MX=mxson[u]; 
}
std::vector<std::pair<int,ll>>t;
void getdist(int u,ll d,int mx,int fa)
{

    for(auto &v:E[u])
    {
        if(v==fa||vis[v]) continue;
        t.emplace_back(std::max(mx,w[v]),d+w[v]);
        getdist(v,d+w[v],std::max(mx,w[v]),u);
    }
}
std::vector<ll>dist;
int tr[N];
int lowbit(int x) {return x&-x;};
void add(int p,int x,int m)
{
    for(;p<=m;p+=lowbit(p)) tr[p]+=x;
}
ll query(int p,int m)
{
   ll ans=0;
   for(;p;p-=lowbit(p)) ans+=tr[p];
    return ans;
}
void solve(int u,int type)
{
    t.clear(),dist.clear();
    getdist(u,1ll*w[u],w[u],0);
    
    for(auto &[mx,d]:t) dist.emplace_back(d);
    std::sort(dist.begin(),dist.end());
    dist.erase(unique(dist.begin(),dist.end()),dist.end());
    int m=dist.size();
    std::sort(t.begin(),t.end());
    for(auto &[mx,d]:t)
    {
        int i=std::lower_bound(dist.begin(),dist.end(),2*mx-d)-dist.begin()+1;
        if(i<=m) res+=1ll*type*(query(m,m)-query(i-1,m));
        add(i,1,m); 
    }
    for(auto &[mx,d]:t)
    {
        int i=std::lower_bound(dist.begin(),dist.end(),d)-dist.begin()+1;
        add(i,-1,m);
    }


}
void Divide(int u)
{
    solve(u,1);
    vis[u]=1;
    for(auto &v:E[u])
    {
        if(vis[v]) continue;
        solve(v,-1);//消除子树重复计算的答案
        S=sz[v],root=0,MX=N;
        getroot(v,0);
        Divide(v);
    }
}
int main()
{
    std::ios::sync_with_stdio(false);
    std::cin.tie(nullptr);
    int T;
    std::cin>>T;
    while(T--)
    {
        int n;
        std::cin>>n;
        for(int i=1;i<=n;i++) std::cin>>(w[i]),E[i].clear(),vis[i]=0;
        for(int i=1;i<n;i++)
        {
            int u,v;
            std::cin>>u>>v;
            E[u].emplace_back(v);
            E[v].emplace_back(u);
        }
        res=0;
        S=n,MX=N,root=0;
        getroot(1,0);
        Divide(root);
        std::cout<<res<<'\n';
    }
}