联考day7 C. 树和森林 树形DP

题目描述

样例

样例输入

8 5
BBWWWBBW
1 2
2 3
4 5
6 7
7 8

样例输出

84
2
1 4

样例解释



分析

首先,我们要预处理出一个点到该联通块内所有点的距离之和 \(f\)
这个东西用换根 \(DP\) 搞一下就可以了
那么这个联通块内所有点对的距离之和就是这个联通块所有点的 \(f\) 值之和除以 \(2\)
除以 \(2\) 是因为点对是无序的
对于子任务一:
当联通块的个数为 \(2\) 时,两个联通块内的贡献我们已经考虑了
我们需要考虑的就是跨过联通块的贡献
我们设从联通块 \(1\) 中选择的点为 \(a\),从联通块 \(2\) 中选择的点为 \(b\) ,联通块的大小是 \(cnt\)
那么贡献就是 \((f[a]+cnt[1])*(n-cnt[1])+cnt[1]*f[b]\)
前半部分统计的是 \(1\) 联通块经过 \((a,b)\) 这条边的贡献
后半部分统计的是 \(2\) 联通块对 \(1\) 联通块的贡献
显然,我们需要把两个联通块内 \(f\) 值最大的点连接起来
当联通块的个数为 \(3\) 时,我们可以枚举哪个联通块在中间
设第一个联通块与第二个联通块通过 \((x, y)\) 相连
第二个联通块与第三个联通块通过 \((u, v)\) 相连。
则联通块与联通块之间的贡献为:
\((f[x]+cnt[1])(n−cnt[1])+(f[v]+cnt[3])(n−cnt[3])+cnt[1]f[y]+cnt[3]f[u]+dis(y,u)cnt[1]cnt[3]\)
其中 \(dis\) 代表两点间的距离
那么 \(x\), \(v\) 应该是联通块 \(1\) 和联通块 \(3\)\(f\) 最大的点。
对于联通块 \(2\),我们只要求出 \(cnt[1]f[y]+cnt[3]f[u]+dis(y,u)cnt[1]cnt[3]\) 的最大值即可,这个
可以通过 \(dp\) 实现
我们分别开两个数组存储当前 \(cnt[1]f[y]\)\(cnt[3]f[u]\)的最大值
自底向上 \(dp\)
对于后面的 \(dis\) 值,我们只需要在向上递归是加一个 \(cnt[1]cnt[3]\) 即可
对于子任务二:
考虑一棵树内的所有不满足条件的点。如果有奇数个这样的点,那么无解,否则一定有解,
并且唯一。
我们要使这些点变成合法的,就需要对它们进行两两匹配,然后改变每一对点路径上所有
边的存在情况。
那么,如果一条边两侧的连通块内有奇数个这样的点,这个边的状态就一定被改变了奇数
次,因此它被删掉了;否则它没有被删掉。
总复杂度 \(O(n)\)

代码

#include<cstdio>
#include<cstring>
#include<iostream>
#define rg register
inline int read(){
	rg int x=0,fh=1;
	rg char ch=getchar();
	while(ch<'0' || ch>'9'){
		if(ch=='-') fh=-1;
		ch=getchar();
	}
	while(ch>='0' && ch<='9'){
		x=(x<<1)+(x<<3)+(ch^48);
		ch=getchar();
	}
	return x*fh;
}
const int maxn=1e6+5;
int n,m,sl,h[maxn],tot=1,rt1,rt2,rt3,siz[maxn],cnt[maxn],vis[maxn];
struct asd{
	int to,nxt;
}b[maxn];
void ad(int aa,int bb){
	b[tot].to=bb;
	b[tot].nxt=h[aa];
	h[aa]=tot++;
}
char s[maxn];
long long f[maxn],g[maxn];
void dfs(int rt,int now,int fa){
	siz[now]=1;
	vis[now]=rt;
	cnt[rt]++;
	for(rg int i=h[now];i!=-1;i=b[i].nxt){
		rg int u=b[i].to;
		if(u==fa) continue;
		dfs(rt,u,now);
		siz[now]+=siz[u];
		g[now]+=g[u]+siz[u];
	}
}
void dfs2(int rt,int now,int fa){
	if(now==rt)f[now]=g[now];
	for(rg int i=h[now];i!=-1;i=b[i].nxt){
		rg int u=b[i].to;
		if(u==fa) continue;
		f[u]=f[now]+cnt[rt]-siz[u]-siz[u];
		dfs2(rt,u,now);
	}
}
int jl,mmax,A,B,C,D;
void solve1(){
	rg long long ans=0;
	for(rg int i=1;i<=n;i++){
		if(vis[i]==0){
			dfs(i,i,0);
			if(!rt1) rt1=i;
			else rt2=i;
		}
	}
	dfs2(rt1,rt1,0);
	dfs2(rt2,rt2,0);
	for(rg int i=1;i<=n;i++){
		ans+=f[i];
	}
	ans/=2;
	mmax=-1,jl=0;
	for(rg int i=1;i<=n;i++){
		if(vis[i]==rt1){
			if(f[i]>mmax){
				mmax=f[i];
				jl=i;
			}
		}
	}
	A=jl;
	mmax=-1,jl=0;
	for(rg int i=1;i<=n;i++){
		if(vis[i]==rt2){
			if(f[i]>mmax){
				mmax=f[i];
				jl=i;
			}
		}
	}
	B=jl;
	ans+=(f[A]+cnt[rt1])*(n-cnt[rt1])+cnt[rt1]*f[B];
	printf("%lld\n",ans);
}
long long maxb[maxn],maxc[maxn],haha=0;
void dfs5(int now,int fa,int cntl,int cntr){
	maxb[now]=cntl*f[now];
	maxc[now]=cntr*f[now];
	for(rg int i=h[now];i!=-1;i=b[i].nxt){
		rg int u=b[i].to;
		if(u==fa) continue;
		dfs5(u,now,cntl,cntr);
		haha=std::max(haha,maxb[now]+maxc[u]+1LL*cntl*cntr);
		haha=std::max(haha,maxc[now]+maxb[u]+1LL*cntl*cntr);
		maxb[now]=std::max(maxb[u]+1LL*cntl*cntr,maxb[now]);
		maxc[now]=std::max(maxc[u]+1LL*cntl*cntr,maxc[now]);
	}
}
long long js(int l,int mids,int r){
	memset(maxb,0,sizeof(maxb));
	memset(maxc,0,sizeof(maxc));
	rg long long jla=0,jld=0;
	haha=0;
	for(rg int i=1;i<=n;i++){
		if(vis[i]==l){
			if(f[i]>jla) jla=f[i];
		}
		if(vis[i]==r){
			if(f[i]>jld) jld=f[i];
		}
	}
	dfs5(mids,0,cnt[l],cnt[r]);
	haha+=(jla+cnt[l])*(n-cnt[l])+(jld+cnt[r])*(n-cnt[r]);
	return haha;
}
void solve2(){
	rg long long ans=0;
	for(rg int i=1;i<=n;i++){
		if(vis[i]==0){
			dfs(i,i,0);
			if(!rt1) rt1=i;
			else if(!rt2)rt2=i;
			else rt3=i;
		}
	}
	dfs2(rt1,rt1,0);
	dfs2(rt2,rt2,0);
	dfs2(rt3,rt3,0);
	for(rg int i=1;i<=n;i++){
		ans+=f[i];
	}
	ans/=2;
	rg long long nans=0;
	nans=std::max(nans,js(rt1,rt2,rt3));
	nans=std::max(nans,js(rt1,rt3,rt2));
	nans=std::max(nans,js(rt2,rt1,rt3));
	ans+=nans;
	printf("%lld\n",ans);
}
int sta[maxn],tp,du[maxn],num[maxn];
bool kil[maxn];
void dfs3(int now,int fa){
	for(rg int i=h[now];i!=-1;i=b[i].nxt){
		rg int u=b[i].to;
		if(u==fa) continue;
		dfs3(u,now);
		num[now]+=num[u];
	}
}
void dfs4(int rt,int now,int fa){
	for(rg int i=h[now];i!=-1;i=b[i].nxt){
		rg int u=b[i].to;
		if(u==fa) continue;
		dfs4(rt,u,now);
		if((num[rt]-num[u])&1 && num[u]&1){
			kil[i]=1;
		}
	}
}
void solve3(){
	for(rg int i=1;i<=n;i++){
		if(s[i]=='B'){
			if(du[i]%2==0) num[i]=1;
		} else {
			if(du[i]&1) num[i]=1;
		}
	}
	if(sl==2){
		dfs3(rt1,0);
		dfs3(rt2,0);
		if(num[rt1]&1 || num[rt2]&1){
			printf("-1\n");
			return;
		}
		dfs4(rt1,rt1,0);
		dfs4(rt2,rt2,0);
	} else {
		dfs3(rt1,0);
		dfs3(rt2,0);
		dfs3(rt3,0);
		if(num[rt1]&1 || num[rt2]&1 || num[rt3]&1){
			printf("-1\n");
			return;
		}
		dfs4(rt1,rt1,0);
		dfs4(rt2,rt2,0);
		dfs4(rt3,rt3,0);
	}
	for(rg int i=1;i<tot;i+=2){
		if(kil[i] || kil[i+1]) continue;
		sta[++tp]=(i+1)/2;
	}
	printf("%d\n",tp);
	for(rg int i=1;i<=tp;i++){
		printf("%d ",sta[i]);
	}
	printf("\n");
}
int main(){
	freopen("lct.in","r",stdin);
	freopen("lct.out","w",stdout);
	memset(h,-1,sizeof(h));
	n=read(),m=read();
	sl=(n-m);
	scanf("%s",s+1);
	rg int aa,bb;
	for(rg int i=1;i<=m;i++){
		aa=read(),bb=read();
		ad(aa,bb);
		ad(bb,aa);
		du[aa]++;
		du[bb]++;
	}
	if(sl==2){
		solve1();
	} else {
		solve2();
	}
	solve3();
	return 0;
}