DP搬运工1 [来自yyy–mengbier的预设型dp]

DP搬运工1

题目描述

给你 \(n,K\) ,求有多少个 \(1\)\(n\) 的排列,满足相邻两个数的 \(max\) 的和不超过 \(K\)

输入格式

一行两个整数 \(n,K\)

输出格式

一行一个整数 \(ans\) 表示答案 \(mod\ 998244353\)

样例

样例输入 1

4 10

样例输出 1

16

样例输入 2

10 66

样例输出 2

1983744

数据范围与提示

\(50\) 个测试点,第 \(i\) 个测试点为 \(n=i\)\(K \leqslant n^2\)

分析

用学长的题解来说这个叫做预设性 \(dp\) (其实也不知道啥意思)。意思大概就是枚举的是当前放哪个数(因为这几个题貌似都是这样)

这个题我们考虑往里边插入数,因为每一次要取 \(max\) ,所以我们根据当前插入的值两边还可不可以放数来进行转移。

如果可以放入一个数,那么当前这个数之对和贡献一次。

如果两边可以放入两个数,那么这个数是没有贡献的。

如果两边都不放数,那么它贡献两次。

所以我们定义 \(f[i][j][k]\) 为放到第 \(i\) 个数,可以放的位置有 \(j\) 个。和为 \(k\)

因为可以放在序列中,也可以放在两端,所以我们分开来考虑。

放在两端的时候就没有两边放两个数的情况了,但是两端有两种情况,所以加的时侯 \(f[i-1][j][k]\) 需要乘以 \(2\)

放在中间就需要考虑了,但是只有在当前数两边放一个的时候才用乘以 \(2\) ,所以我们就可以愉快的转移了。

代码

#include<cstdio>
#include<algorithm>
#include<iostream>
#include<cstring>
#include<cmath>
using namespace std;
//以下好多行是卡常
const int L=1<<20;
char buffer[L],*S,*T;
#define lowbit(x) (x & -x)
#define getchar() (S==T&&(T=(S=buffer)+fread(buffer,1,L,stdin),S==T)?EOF:*S++)
#define inline __inline__ __attribute__((__always_inline__))
#define max(a,b) (a>b?a:b)
#define re register 
const int maxn = 52;
const int mod = 998244353;
int f[maxn][maxn][maxn*maxn];
int n;
inline int read(){
	int s = 0,f = 1;
	char ch = getchar();
	while(!isdigit(ch)){
		if(ch == '-')f = -1;
		 ch = getchar();
	}
	while(isdigit(ch)){
		s = s * 10 + ch - '0';
		ch = getchar();
	}
	return s * f;
}
int main(){
	n = read();
	int K = read();
	f[1][0][0] = 1;
	for(int i = 2;i <= n;++i){
		int jl1 = min(i,n-i)+1;	//找到当前最多有多少位置能放
		int jl2 = min(K,i*i);//找到当前最大的和
		for(int j = 0;j <= jl1; ++j){
			for(int k = 0;k <= jl2; ++k){
				if(!f[i-1][j][k])continue;
				int jl = f[i-1][j][k] * 2 % mod;//第一种放在两端的情况
				f[i][j+1][k] = (f[i][j+1][k] + jl) % mod;
				f[i][j][k+i] = (f[i][j][k+i] + jl) % mod;
				if(!j)continue;
				jl = f[i-1][j][k] * 1ll * j % mod;//以下是放在序列中间的情况
				f[i][j+1][k] = (f[i][j+1][k] + jl) % mod;
				f[i][j][k+i] = (f[i][j][k+i] + jl * 2ll % mod) % mod;
				f[i][j-1][k+2*i] = (f[i][j-1][k+2*i] + jl) % mod;
			}
		}
	}
	int ans = 0;
	for(int i = 0;i <= K;++i){//把小于等于 K 的所有情况都加起来
		ans = (ans + f[n][0][i]) % mod;
	}
	printf("%d\n",ans);
}