树上分治
树上分治
点分治 \(O(nlogn)\)
主要解决有关树上路径统计的问题(其中路径的边权可能需要满足一些条件)
1.基本思想:点分治的本质其实是将一棵树拆分成许多棵子树处理,并不断进行。
2.分治点的选择:树的重心
3.点分治
- 路径的两个端点在同一个子树内
- 路径的两个端点不在同一个子树内
- 路径的某个端点是重心
基本模板
#include <bits/stdc++.h>
#define fi first
#define se second
#define ll long long
using namespace std;
const int N=1e4+5;
vector<pair<int,int>>E[N];
int n,k,S; //S记录根据重心划分之后当前遍历的子树的大小
int sz[N],mxson[N];
bool vis[N];//是否是重心
int MX,root,dist[N],cnt;
ll ans;
void getroot(int u,int fa)
{
sz[u]=1,mxson[u]=0;
for(auto [v,w]:E[u])
{
if(v==fa||vis[v]) continue;
getroot(v,u);
sz[u]=sz[u]+sz[v];
mxson[u]=max(mxson[u],sz[v]);
}
mxson[u]=max(mxson[u],S-sz[u]);
if(mxson[u]<MX) root=u,MX=mxson[u];
}
void getdist(int u,int fa)
{
}
int solve(int u)
{
}
//分治难点在于统计合并答案
void Divide(int u)
{
solve(u,1); //统计答案
vis[u]=1;
for(auto [v,w]:E[u])
{
if(vis[v]) continue;
solve(v,-1); //统计答案(这里可能会用到容斥之类的)
S=sz[v],root=0,MX=N;
getroot(v,0);
Divide(root);
}
return;
}
int main()
{
ios::sync_with_stdio(false);
cin.tie(0);
while(cin>>n>>k,n&&k)
{
for(int i=1;i<=n;i++) vis[i]=0,E[i].clear();
for(int i=1;i<n;i++)
{
int u,v,w;
cin>>u>>v>>w;
u++,v++;
E[u].push_back({v,w});
E[v].push_back({u,w});
}
ans=0;
S=n,MX=N;
getroot(1,0);
Divide(root);
cout<<ans<<'\n';
}
}
基础例题
1.
给定一个有 \(N\)个点(编号 \(0,1,…,N−1\))的树,每条边都有一个权值(不超过 1000)。
树上两个节点 \(x\) 与 \(y\) 之间的路径长度就是路径上各条边的权值之和。
求长度不超过 \(K\) 的路径有多少条。
思路:分治后在每个子树中暴力求出所有点到分治点的距离,然后将所有距离存在一个数组中排序,利用双指针的方法统计答案,然后对于记重的边用容斥原理去解决。
#include <bits/stdc++.h>
#define fi first
#define se second
#define ll long long
using namespace std;
const int N=1e4+5;
vector<pair<int,int>>E[N];
int n,k,S; //S记录根据重心划分之后当前遍历的子树的大小
int sz[N],mxson[N];
bool vis[N];//是否是重心
int MX,root,dist[N],cnt;
ll ans;
void getroot(int u,int fa)
{
sz[u]=1,mxson[u]=0;
for(auto [v,w]:E[u])
{
if(v==fa||vis[v]) continue;
getroot(v,u);
sz[u]=sz[u]+sz[v];
mxson[u]=max(mxson[u],sz[v]);
}
mxson[u]=max(mxson[u],S-sz[u]);
if(mxson[u]<MX) root=u,MX=mxson[u];
}
void getdist(int u,int fa,int d)
{
dist[++cnt]=d;
for(auto [v,w]:E[u])
{
if(v==fa||vis[v]) continue;
getdist(v,u,w+d);
}
return;
}
int solve(int u,int d)
{
cnt=0;
memset(dist,0,sizeof dist);
getdist(u,0,d);
sort(dist+1,dist+1+cnt);
int l=1,r=cnt,res=0;
while(l<=r) //排序后双指针统计答案
{
if(dist[r]+dist[l]<=k) res+=r-l,l++;
else r--;
}
return res;
}
void Divide(int u)
{
ans=ans+solve(u,0);
vis[u]=1;
for(auto [v,w]:E[u])
{
if(vis[v]) continue;
ans=ans-solve(v,w); //容斥原理
S=sz[v],root=0,MX=N;
getroot(v,0);
Divide(root);
}
return;
}
int main()
{
ios::sync_with_stdio(false);
cin.tie(0);
while(cin>>n>>k,n&&k)
{
for(int i=1;i<=n;i++) vis[i]=0,E[i].clear();
for(int i=1;i<n;i++)
{
int u,v,w;
cin>>u>>v>>w;
u++,v++;
E[u].push_back({v,w});
E[v].push_back({u,w});
}
ans=0;
S=n,MX=N;
getroot(1,0);
Divide(root);
cout<<ans<<'\n';
}
}
2.
给定一棵 \(N\)个节点的树,每条边带有一个权值。
求一条简单路径,路径上各条边的权值和等于 \(K\),且路径包含的边的数量最少。
#include <bits/stdc++.h>
#define ll long long
using namespace std;
constexpr int N=2e5+10,M=1e6+10,inf=0x7f7f7f7f;
vector<pair<int,int>>E[N];
int n,k;
int root,S,MX;
int sz[N],mxson[N];
bool vis[N];
int dist[N],num[M];
void getroot(int u,int fa)
{
sz[u]=1,mxson[u]=0;
for(auto [v,w]:E[u])
{
if(v==fa||vis[v]) continue;
getroot(v,u);
sz[u]=sz[u]+sz[v];
mxson[u]=max(mxson[u],sz[v]);
}
mxson[u]=max(mxson[u],S-sz[u]);
if(mxson[u]<MX) root=u,MX=mxson[u];
}
vector<pair<int,int>>t;
void getdist(int u,int fa,int K)//求出这个子树中其他的点到这个子树的重心的距离
{
if(dist[u]<=k) t.emplace_back(dist[u],K);
for(auto &[v,w]:E[u])
{
if(vis[v]||v==fa) continue;
dist[v]=dist[u]+w;
getdist(v,u,K+1);
}
}
int res;
int q[N],tt=-1;
int solve(int u)
{
num[0]=0;
for(auto [v,w]:E[u])
{
if(vis[v]) continue;
t.clear();
dist[v]=w;
getdist(v,u,1);//统计分离重心后的某一棵子树
for(auto &[d,cnt]:t)
res=min(res,cnt+num[k-d]);
for(auto &[d,cnt]:t)
{
num[d]=min(num[d],cnt);
q[++tt]=d;
}
}
while(tt>=0) num[q[tt--]]=inf;
}
void Divide(int u)
{
solve(u);
vis[u]=1;
for(auto [v,w]:E[u])
{
if(vis[v]) continue;
S=sz[v],root=0,MX=N;
getroot(v,0);
Divide(root);
}
return;
}
int main()
{
ios::sync_with_stdio(false);
cin.tie(nullptr);
memset(num,0x7f,sizeof num);
cin>>n>>k;
for(int i=1;i<n;i++)
{
int u,v,w;
cin>>u>>v>>w;
++u,++v;
E[u].emplace_back(v,w);
E[v].emplace_back(u,w);
}
S=n,MX=N,root=0,res=inf;
getroot(1,0);
Divide(root);
if(res==inf) cout<<-1<<'\n';
else cout<<res<<'\n';
return 0;
}
进阶题:
1.Prime Distance On Tree
题意:给定一个\(n\)个节点的数,问树上任选两点间且这两点间的距离为素数的概率是多少
Sol:点分治,每次递归处理时,假设当前的重心是\(root\),求出分裂后的子树的所有点到\(root\)的距离,并统计它们的个数。
记\(cnt[d]\)为距离为\(d\)的路径的个数。同时还要求出所有距离相加的个数,因为这些距离是可以组合的,这一步可以利用\(NTT\)快速求出。
多项式形式为\(\sum cnt[d]x^d\),求完后暴力统计所有长度为素数的边的个数。算完后统计答案由于\(u\)到\(v\)和\(v\)到\(u\)会统计两遍,所以答案还要除以2。分治到子树时,由于答案会可能会算重,因为在统计\(root\)的子树时可能会算到在同一个子树内的两点,所以在统计子树时要先利用容斥原理减去算重的,然后再分治算子树。
代码还有问题,不过大概思路就是这样,先插个眼
#include <bits/stdc++.h>
#define ll long long
#define pb push_back
using namespace std;
const int N = 5e6+10;
const int p = 998244353, gg = 3, ig = 332738118, img = 86583718;//1004535809
const int mod=998244353;
template <typename T> void rd (T &x)
{
x=0;int f=1;char s=getchar();
while(s<'0'||s>'9'){if(s=='-') f=-1;s=getchar();}
while(s>='0'&&s<='9') x=x*10+(s^48),s=getchar();
x*=f;
}
ll qpow(ll a, int b)
{
ll res = 1;
while(b) {
if(b & 1) res = 1ll * res * a % mod;
a = 1ll * a * a % mod;
b >>= 1;
}
return res;
}
vector<ll>A(N);
namespace Poly
{
#define mul(x, y) (1ll * x * y >= mod ? 1ll * x * y % mod : 1ll * x * y)
#define minus(x, y) (1ll * x - y < 0 ? 1ll * x - y + mod : 1ll * x - y)
#define plus(x, y) (1ll * x + y >= mod ? 1ll * x + y - mod : 1ll * x + y)
#define ck(x) (x >= mod ? x - mod : x)//取模运算太慢了
typedef vector<ll> poly;
const int G = 3;//根据具体的模数而定,原根可不一定不一样!!!
//一般模数的原根为 2 3 5 7 10 6
const int inv_G = qpow(G, mod - 2);
int RR[N], deer[2][21][N], inv[N];
void init(const int t) {//预处理出来NTT里需要的w和wn,砍掉了一个log的时间
for(int p = 1; p <= t; ++ p) {
int buf1 = qpow(G, (mod - 1) / (1 << p));
int buf0 = qpow(inv_G, (mod - 1) / (1 << p));
deer[0][p][0] = deer[1][p][0] = 1;
for(int i = 1; i < (1 << p); ++ i) {
deer[0][p][i] = 1ll * deer[0][p][i - 1] * buf0 % mod;//逆
deer[1][p][i] = 1ll * deer[1][p][i - 1] * buf1 % mod;
}
}
inv[1] = 1;
for(int i = 2; i <= (1 << t); ++ i)
inv[i] = 1ll * inv[mod % i] * (mod - mod / i) % mod;
}
int NTT_init(int n) {//快速数论变换预处理
int limit = 1, L = 0;
while(limit < n) limit <<= 1, L ++ ;
for(int i = 0; i < limit; ++ i)
RR[i] = (RR[i >> 1] >> 1) | ((i & 1) << (L - 1));
return limit;
}
void NTT(poly &A, int type, int limit) {//快速数论变换
A.resize(limit);
for(int i = 0; i < limit; ++ i)
if(i < RR[i])
swap(A[i], A[RR[i]]);
for(int mid = 2, j = 1; mid <= limit; mid <<= 1, ++ j) {
int len = mid >> 1;
for(int pos = 0; pos < limit; pos += mid) {
int *wn = deer[type][j];
for(int i = pos; i < pos + len; ++ i, ++ wn) {
int tmp = 1ll * (*wn) * A[i + len] % mod;
A[i + len] = ck(A[i] - tmp + mod);
A[i] = ck(A[i] + tmp);
}
}
}
if(type == 0) {
for(int i = 0; i < limit; ++ i)
A[i] = 1ll * A[i] * inv[limit] % mod;
}
}
inline poly poly_mul(poly A, poly B) {//多项式乘法
int deg = A.size() + B.size() - 1;
int limit = NTT_init(deg);
poly C(limit);
NTT(A, 1, limit);
NTT(B, 1, limit);
for(int i = 0; i < limit; ++ i)
C[i] = 1ll * A[i] * B[i] % mod;
NTT(C, 0, limit);
C.resize(deg);
return C;
}
//多个多项式相乘CDQ或者利用优先队列启发式合并
inline poly CDQ(int l,int r)
{
if(l==r)
{
return poly{1,A[l]};
}
int mid=l+r>>1;
poly L=CDQ(l,mid);
poly R=CDQ(mid+1,r);
return poly_mul(L,R);
}
}
using namespace Poly;
constexpr int MAXN=5e4+10;
vector<int>E[MAXN];
int mxson[MAXN],sz[MAXN];
int n,S,MX,root;
bool vis[MAXN],st[MAXN];
int primes[MAXN],primes_cnt;
void get_primes()
{
for(int i=2;i<=50000;i++)
{
if(!st[i]) primes[++primes_cnt]=i;
for(int j=1;primes[j]*i<=50000;i++)
{
st[primes[j]*i]=1;
if(i%primes[j]==0) break;
}
}
}
void getroot(int u,int fa)
{
sz[u]=1,mxson[u]=0;
for(auto &v:E[u])
{
if(v==fa||vis[v]) continue;
getroot(v,u);
sz[u]=sz[u]+sz[v];
mxson[u]=max(mxson[u],sz[v]);
}
mxson[u]=max(mxson[u],S-sz[u]);
if(mxson[u]<MX) root=u,MX=mxson[u];
}
int cnt[MAXN],Len;
ll res=0;
void getdist(int u,int fa,int d)
{
++cnt[d];
Len=max(Len,d);
for(auto &v:E[u])
{
if(vis[v]||v==fa) continue;
getdist(v,u,d+1);
}
}
void solve(int u,int d,int type)
{
Len=0;
ll sum=0;
getdist(u,0,d);
poly F(Len+1);
for(int i=0;i<=Len;i++) F[i]=cnt[i];
sum-=cnt[1];
poly G=poly_mul(F,F);
int limit=G.size();
for(int i=1;i<=primes_cnt;i++)
if(primes[i]>limit) break;
else sum+=G[primes[i]];
res+=sum/2*type;
for(int i=0;i<=Len;i++) cnt[i]=0;
}
void Divide(int u)
{
solve(u,0,1);
vis[u]=1;
for(auto &v:E[u])
{
if(vis[v]) continue;
solve(v,1,-1);
S=sz[v],root=0,MX=MAXN;
getroot(v,0);
Divide(v);
}
}
int main()
{
//freopen("in.in","r",stdin);
//ios::sync_with_stdio(false);
//cin.tie(nullptr);
get_primes();
init(20);//2^21 = 2,097,152,根据题目数据多项式项数的大小自由调整,注意大小需要跟deer数组同步(21+1=22)
rd(n);
for(int i=1;i<n;i++)
{
int u,v;
rd(u),rd(v);
E[u].emplace_back(v);
E[v].emplace_back(u);
}
S=n,MX=MAXN,root=0;
getroot(1,0);
Divide(root);
printf("%.1lf\n",res*2.0/(1.0*n*(n-1)));
return 0;
}
2.Constructing Ranches$ (2019-ICPC-HongKong)
题意:给定一棵\(n\)个点的树,每个点有一个点权,问有多少条路径满足路径上的点权可以构成一个简单的凸多边形。
Sol:首先这道题显然点分治,然后利用以下结论,维护路径上的和还有最大值。
\(Lemma\):\(n\) 条边,长度为 \(a_1,a_2…a_n(a_i≤a_{i+1})\) ,其能构成一个 面积大于0的凸多边形 ,当且仅当 \(n>2\) 且 \(\sum_{i=1}^{n-1}a_i\gt 2\times a_n\).
在当前子树中,找出重心后求出其他点到重心的距离、最大值后,在合并跨过重心的答案时,考虑利用树状数组,因为如果两两枚举的话容易T,利用树状数组,将距离离散化后当做树状数组的下标,将最大值排序,依次枚举最大值,(也是在枚举边)在树状数组中找到第一个大于等于\(2*mx-d\)的值,然后查询\(query(m)-query(id[d-1])\),其中\(m\)为离散化后的最大下标,这样为什么可行是因为\(id[d-1]\)后的距离一定大于等于\(d\),那么\(id[d-1]\)后面有几个数,说明就有多少条边可以和当前枚举的边配对,因为\(2*mx-d<x\)等价于\(2*mx<d+x\)。
还有错误,待查
#include <bits/stdc++.h>
#define ll long long
//using namespace std;
constexpr int N=2e5+10;
int n,S,MX,root;
std::vector<int>E[N];
int sz[N],mxson[N];
ll res;
int w[N];
bool vis[N];
void getroot(int u,int fa)
{
sz[u]=1,mxson[u]=0;
for(auto &v:E[u])
{
if(vis[v]||v==fa) continue;
sz[u]+=sz[v];
mxson[u]=std::max(mxson[u],sz[v]);
}
mxson[u]=std::max(mxson[u],S-sz[u]);
if(mxson[u]<MX) root=u,MX=mxson[u];
}
std::vector<std::pair<int,ll>>t;
void getdist(int u,ll d,int mx,int fa)
{
for(auto &v:E[u])
{
if(v==fa||vis[v]) continue;
t.emplace_back(std::max(mx,w[v]),d+w[v]);
getdist(v,d+w[v],std::max(mx,w[v]),u);
}
}
std::vector<ll>dist;
int tr[N];
int lowbit(int x) {return x&-x;};
void add(int p,int x,int m)
{
for(;p<=m;p+=lowbit(p)) tr[p]+=x;
}
ll query(int p,int m)
{
ll ans=0;
for(;p;p-=lowbit(p)) ans+=tr[p];
return ans;
}
void solve(int u,int type)
{
t.clear(),dist.clear();
getdist(u,1ll*w[u],w[u],0);
for(auto &[mx,d]:t) dist.emplace_back(d);
std::sort(dist.begin(),dist.end());
dist.erase(unique(dist.begin(),dist.end()),dist.end());
int m=dist.size();
std::sort(t.begin(),t.end());
for(auto &[mx,d]:t)
{
int i=std::lower_bound(dist.begin(),dist.end(),2*mx-d)-dist.begin()+1;
if(i<=m) res+=1ll*type*(query(m,m)-query(i-1,m));
add(i,1,m);
}
for(auto &[mx,d]:t)
{
int i=std::lower_bound(dist.begin(),dist.end(),d)-dist.begin()+1;
add(i,-1,m);
}
}
void Divide(int u)
{
solve(u,1);
vis[u]=1;
for(auto &v:E[u])
{
if(vis[v]) continue;
solve(v,-1);//消除子树重复计算的答案
S=sz[v],root=0,MX=N;
getroot(v,0);
Divide(v);
}
}
int main()
{
std::ios::sync_with_stdio(false);
std::cin.tie(nullptr);
int T;
std::cin>>T;
while(T--)
{
int n;
std::cin>>n;
for(int i=1;i<=n;i++) std::cin>>(w[i]),E[i].clear(),vis[i]=0;
for(int i=1;i<n;i++)
{
int u,v;
std::cin>>u>>v;
E[u].emplace_back(v);
E[v].emplace_back(u);
}
res=0;
S=n,MX=N,root=0;
getroot(1,0);
Divide(root);
std::cout<<res<<'\n';
}
}