­

wqs二分學習筆記

內容

\(wqs\) 二分又名凸優化、帶權二分。

一般用於 \(n\) 個物品強制選 \(k\) 個的情況下的最優化問題。

這樣的問題直接 \(dp\) 複雜度一般都比較高,因為要求強制選 \(k\) 個,所以要有一維來記錄選了多少物品。

\(wqs\) 二分則可以把這種限制去掉。

首先我們二分一個權值 \(C\),強行給每一個物品都加上這一個權值。

然後跑一遍沒有選 \(k\) 個物品的限制的 \(dp\)

最後根據最優值所選擇的物品個數來調整二分端點。

能夠用 \(wqs\) 二分優化的 \(dp\) 要滿足 \(dp\) 得到的結果是凸的。

也就是說,如果把橫坐標看作強制選擇的物品個數,縱坐標看作函數值,那麼相鄰兩點之間的斜率應該是單調的。

之所以要有這個限制,是因為我們二分的附加權值實際上是斜率。

假設要求的是最大值,我們拿一條斜率為 \(k\) 的直線去切這個凸包,那麼切到的點的截距一定是最大的。

但是我們並不知道我們具體切到了哪一個點,所以需要去計算。

根據直線的斜截式 \(y=kx+b\),截距 \(b=y-kx\)

我們可以把截距 \(b\) 也看成一個一次函數,那麼如果能求出 \(b\) 的最值也就知道了當前切到的點的橫坐標。

觀察 \(b\) 的表達式,實際上就相當於給每一種物品減去了一個權值。

所以我們只要給物品減去權值之後跑一次不帶限制的 \(dp\),求出最優的情況了選擇了幾個物品,就能知道切到的是哪一個點了。

但是有的時候會出現斜率相等的情況,這是就需要我們強制規定選橫坐標最大/小的點。

例題

P2619 [國家集訓隊2]Tree I

題目傳送門

分析

求恰好有 \(k\) 條白邊的最小生成樹。

可以給每一條白邊加上額外的邊權去跑最小生成樹。

如果得到的生成樹中白邊比想要的多,就說明我們加的權值少了,要多加點,否則就少加點。

在斜率相等時,我們強制選擇白邊,也就是橫坐標最大的點。

程式碼

#include<cstdio>
#include<algorithm>
#define rg register
inline int read(){
	rg int x=0,fh=1;
	rg char ch=getchar();
	while(ch<'0' || ch>'9'){
		if(ch=='-') fh=-1;
		ch=getchar();
	}
	while(ch>='0' && ch<='9'){
		x=(x<<1)+(x<<3)+(ch^48);
		ch=getchar();
	}
	return x*fh;
}
const int maxn=1e5+5;
int n,m,k,ans,fa[maxn],cnt,whicnt,sum;
struct asd{
	int zb,yb,val,jud;
}b[maxn];
bool cmp(rg asd aa,rg asd bb){
	return aa.val==bb.val?aa.jud<bb.jud:aa.val<bb.val;
}
int zhao(rg int xx){
	if(xx==fa[xx]) return xx;
	return fa[xx]=zhao(fa[xx]);
}
bool jud(rg int val){
	ans=cnt=whicnt=0;
	for(rg int i=1;i<=m;i++) b[i].val+=(!b[i].jud)*val;
	for(rg int i=1;i<=n;i++) fa[i]=i;
	std::sort(b+1,b+m+1,cmp);
	rg int aa,bb;
	for(rg int i=1;i<=m;i++){
		aa=b[i].zb,bb=b[i].yb;
		aa=zhao(aa),bb=zhao(bb);
		if(aa==bb) continue;
		whicnt+=(b[i].jud==0),cnt++,fa[aa]=bb,ans+=b[i].val;
		if(cnt==n-1) break;
	}
	for(rg int i=1;i<=m;i++) b[i].val-=(!b[i].jud)*val;
	return whicnt>=k;
}
int main(){
	n=read(),m=read(),k=read();	
	for(rg int i=1;i<=m;i++) b[i].zb=read()+1,b[i].yb=read()+1,b[i].val=read(),b[i].jud=read();
	rg int l=-200,r=200,mids;
	while(l<=r){
		mids=(l+r)>>1;
		if(jud(mids)){
			l=mids+1;
			sum=ans-k*mids;
		} else {
			r=mids-1;
		}
	}
	printf("%d\n",sum);
	return 0;
}

P4383 [八省聯考2018]林克卡特樹

題目傳送門

分析

實際上是讓你從樹上選擇 \(k+1\) 條點不相交的鏈,使權值最大。

考慮 \(60\) 分的 \(dp\) 做法。

\(f[i][j][0/1/2]\) 為在 \(i\) 的子樹中選擇了 \(j\) 條鏈,\(i\) 的度數為 \(0,1,2\) 時的最大值。

之所以要加上度數的限制是為了合併子樹的時候能夠更好地處理資訊。

度數為 \(0\) 代表當前點不在鏈上

度數為 \(1\) 代表當前點是鏈的一個端點

度數為 \(2\) 代表當前點在一條鏈的中心

每一次轉移之後,我們都令 \(f[now][j][0]=max(f[now][j][0],max(f[now][j][2],f[now][j-1][1]))\)

這樣我們在更新父親節點的時候就不用特判很多情況

\(u\)\(now\) 的兒子,\(val\) 代表邊權

\(f[now][j][2]\) 可以由 \(f[now][k][2]+f[u][j-k][0]\)\(f[now][k][1]+f[u][j-k-1][1]+val\) 更新而來

含義分別是繼承之前的資訊,當前點所在的鏈的一段與兒子節點所在的鏈的一端拼和成一條新的鏈並且當前點處在鏈的中央

\(f[now][j][1]\) 可以由 \(f[now][k][1]+f[u][j-k][0]\)\(f[now][k][0]+f[u][j-k-1][1]+val\) 更新而來

含義分別是繼承之前的資訊,當前邊與兒子節點所在的鏈的一端拼和成一條新的鏈並且讓當前節點作為鏈的一端

\(f[now][j][0]\) 直接繼承 \(f[now][k][0]+f[u][j-k][0]\) 即可

一開始的時候要把一個節點也當鏈處理,即 \(f[now][0][0]=f[now][0][1]=f[now][1][2]=0\)

打表可得函數值是一個凸函數,斜率單調不增

所以可以用 \(wqs\) 二分優化

每次強制給每一條鏈加上一個權值,算一下最優的情況下選擇了多少鏈

如果選擇的鏈比想要的多,那麼增加附加權值,少選一些

否則減小附加權值,多選一些

斜率相等的時候強制選擇最左邊的點

注意一下數組更新的順序就行了

程式碼

#include<cstdio>
#include<algorithm>
#include<cstring>
#include<cmath>
#define rg register
inline int read(){
	rg int x=0,fh=1;
	rg char ch=getchar();
	while(ch<'0' || ch>'9'){
		if(ch=='-') fh=-1;
		ch=getchar();
	}
	while(ch>='0' && ch<='9'){
		x=(x<<1)+(x<<3)+(ch^48);
		ch=getchar();
	}
	return x*fh;
}
const int maxn=3e5+5;
typedef long long ll;
int h[maxn],tot=1,n,k;
struct asd{
	int to,nxt,val;
}b[maxn<<1];
void ad(rg int aa,rg int bb,rg int cc){
	b[tot].to=bb;
	b[tot].nxt=h[aa];
	b[tot].val=cc;
	h[aa]=tot++;
}
struct jie{
	int cnt;
	ll val;
	jie(){}
	jie(rg int aa,rg ll bb){
		cnt=aa,val=bb;
	}
	friend jie operator + (const jie& A,const jie& B){
		return jie(A.cnt+B.cnt,A.val+B.val);
	}
	friend bool operator < (const jie& A,const jie& B){
		if(A.val==B.val) return A.cnt<B.cnt;
		return A.val<B.val;
	}
}f[maxn][3];
jie Max(rg jie aa,rg jie bb){
	return aa<bb?bb:aa;
}
void dfs(rg int now,rg int lat,rg ll val){
	f[now][1]=f[now][0]=jie(0,0);
	f[now][2]=jie(1,val);
	for(rg int i=h[now];i!=-1;i=b[i].nxt){
		rg int u=b[i].to;
		if(u==lat) continue;
		dfs(u,now,val);
		f[now][2]=Max(f[now][2]+f[u][0],f[now][1]+f[u][1]+jie(1,b[i].val+val));
		f[now][1]=Max(f[now][0]+f[u][1]+jie(0,b[i].val),f[now][1]+f[u][0]);
		f[now][0]=Max(f[now][0],f[now][0]+f[u][0]);
	}
	f[now][0]=Max(f[now][0],Max(f[now][2],f[now][1]+jie(1,val)));
}
void init(){
	for(rg int i=1;i<=n;i++){
		f[i][0].cnt=f[i][1].cnt=f[i][2].cnt=0;
		f[i][0].val=f[i][1].val=f[i][2].val=-0x3f3f3f3f3f3f3f3f;
	}
}
int main(){
	memset(h,-1,sizeof(h));
	n=read(),k=read();
	rg int aa,bb,cc;
	for(rg int i=1;i<n;i++){
		aa=read(),bb=read(),cc=read();
		ad(aa,bb,cc);
		ad(bb,aa,cc);
	}
	k++;
	rg long long l=-3e11,r=3e11,mids,ans;
	while(l<=r){
		mids=(l+r)>>1;
		init();
		dfs(1,0,mids);
		if(f[1][0].cnt<k){
			l=mids+1;
		} else {
			ans=f[1][0].val-1LL*mids*k;
			r=mids-1;
		}
	}
	printf("%lld\n",ans);
	return 0;
}