虛樹學習筆記

作用

虛樹常常被使用在樹形 \(dp\)中。

有些時候,我們需要計算的節點僅僅是一棵樹中的某幾個節點

這個時候如果對整棵樹都進行一次計算開銷太大了

所以我們需要把這些節點從原樹中抽象出來

按照它們在原樹中的關係重新建一棵樹,這樣的樹就是虛樹

構建方法

在構建之前,我們需要把所有需要加入的節點按照 \(dfn\) 序從小到大排好序

在加點時,我們要用棧維護一個最右鏈

在這個鏈左邊的虛樹都已經構建完成

我們設 \(top\) 為棧頂,設要加入的節點為 \(now\),設棧頂元素與 \(now\)\(LCA\)\(lc\)

在加入的時候,會有以下幾種情況

\(1\)\(lc=sta[top]\)

此時我們直接把 \(now\) 接在最右鏈之後即可
\(2\)\(lc\) 位於 \(sta[top]\)\(sta[top-1]\)之間

此時 \(sta[tp]\) 已經不在最右鏈上,將其在虛樹上和 \(lc\) 連邊後出棧

同時把 \(lc\)\(now\) 依次入棧

\(3\)\(lc\)\(sta[top-1]\)

和上面幾乎一樣,只是不把 \(lc\) 入棧

\(4\)\(lc\) 的深度比 \(sta[top-1]\) 還小

我們把 \(sta[top]\)\(sta[top-1]\) 連邊後出棧,重複之前的操作

這樣,我們直接在建出來的虛樹上 \(dp\) 就可以了

設總點數為 \(k\),則時間複雜度為 \(O(klogk)\)

程式碼實現

P2495 [SDOI2011]消耗戰為例

#include<cstdio>
#include<iostream>
#include<cstring>
#include<algorithm>
#define rg register
inline int read(){
	rg int x=0,fh=1;
	rg char ch=getchar();
	while(ch<'0' || ch>'9'){
		if(ch=='-') fh=-1;
		ch=getchar();
	}
	while(ch>='0' && ch<='9'){
		x=(x<<1)+(x<<3)+(ch^48);
		ch=getchar();
	}
	return x*fh;
}
const int maxn=1e6+5;
int h[maxn],tot=1,h2[maxn],t2=1;
struct asd{
	int to,nxt,val;
}b[maxn],b2[maxn];
void ad(rg int aa,rg int bb,rg int cc){
	b[tot].to=bb;
	b[tot].val=cc;
	b[tot].nxt=h[aa];
	h[aa]=tot++;
}
void ad2(rg int aa,rg int bb){
	b2[t2].to=bb;
	b2[t2].nxt=h2[aa];
	h2[aa]=t2++;
}
int n,m,fa[maxn],dep[maxn],son[maxn],siz[maxn];
long long mindis[maxn];
void dfs1(rg int now,rg int lat){
	fa[now]=lat;
	dep[now]=dep[lat]+1;
	siz[now]=1;
	for(rg int i=h[now];i!=-1;i=b[i].nxt){
		rg int u=b[i].to;
		if(u==lat) continue;
		mindis[u]=std::min(mindis[now],1LL*b[i].val);
		dfs1(u,now);
		siz[now]+=siz[u];
		if(son[now]==0 || siz[u]>siz[son[now]]) son[now]=u;
	}
}
int dfn[maxn],dfnc,tp[maxn],stk[maxn],cnt,sta[maxn],js;
void dfs2(rg int now,rg int top){
	tp[now]=top;
	dfn[now]=++dfnc;
	if(son[now]) dfs2(son[now],top);
	for(rg int i=h[now];i!=-1;i=b[i].nxt){
		rg int u=b[i].to;
		if(u==son[now] || u==fa[now]) continue;
		dfs2(u,u);
	}
}
bool cmp(rg int aa,rg int bb){
	return dfn[aa]<dfn[bb];
}
int get_lca(rg int u,rg int v){
	while(tp[u]!=tp[v]){
		if(dep[tp[u]]<dep[tp[v]]) std::swap(u,v);
		u=fa[tp[u]];
	}
	if(dep[u]<dep[v]) return u;
	else return v;
}
void init(rg int now){
	rg int lca=get_lca(now,sta[js]);
	while(1){
		if(dfn[lca]>=dfn[sta[js-1]]){
			if(lca!=sta[js]){
				ad2(sta[js],lca);
				ad2(lca,sta[js]);
				if(lca!=sta[js-1]){
					sta[js]=lca;
				} else {
					js--;
				}
			}
			break;
		} else {
			ad2(sta[js],sta[js-1]);
			ad2(sta[js-1],sta[js]);
			js--;
		}
	}
	sta[++js]=now;
}
bool vis[maxn];
long long dfs(rg int now,rg int lat){
	rg long long ans=0,cs=0;
	for(rg int i=h2[now];i!=-1;i=b2[i].nxt){
		rg int u=b2[i].to;
		if(u==lat) continue;
		ans+=dfs(u,now);
	}
	if(vis[now]){
		cs=mindis[now];
	} else {
		cs=std::min(mindis[now],ans);
	}
	vis[now]=0;
	h2[now]=-1;
	return cs;
}
int main(){
	memset(h,-1,sizeof(h));
	memset(h2,-1,sizeof(h2));
	memset(mindis,0x7f,sizeof(mindis));
	n=read();
	rg int aa,bb,cc;
	for(rg int i=1;i<n;i++){
		aa=read(),bb=read(),cc=read();
		ad(aa,bb,cc);
		ad(bb,aa,cc);
	}
	dfs1(1,0);
	dfs2(1,1);
	sta[0]=1;
	m=read();
	for(rg int i=1;i<=m;i++){
		cnt=read();
		t2=1;
		for(rg int j=1;j<=cnt;j++){
			aa=read();
			stk[j]=aa;
			vis[aa]=1;
		}
		std::sort(stk+1,stk+cnt+1,cmp);
		sta[js=1]=stk[1];
		for(rg int j=2;j<=cnt;j++){
			init(stk[j]);
		}
		while(js>0){
			ad2(sta[js],sta[js-1]);
			ad2(sta[js-1],sta[js]);
			js--;
		}
		printf("%lld\n",dfs(1,0));
	}
	return 0;
}