[学习笔记] FFT

填一下\(FFT\) 的坑。

我们的目标是求\(A\times B\)的系数(\(A,B\)为系数已知的多项式)

关于 \(\omega\)

大概讲一下

我们有复数 \(\omega\)

\(\omega_n^k=\cos \frac{2k\pi}{n}+i\times\sin\frac{2k\pi}{n}\)

可以把它看作把单位圆进行n等分,然后进行标号,\(k\) 每增加1,即使逆时针往后移一个

这东西有几个性质:

  • \(\omega_n^{k+n}=\omega_n^k\)

相当于转了一圈又回来了

  • \(\omega_n^{k+\frac n 2}=-\omega_n^k\)

转到对面去了

  • \(\omega_n^k=\omega_{\frac n 2}^{\frac k 2}\)

进行了缩放,但还是在同一个位置

大体思路

对于一个n-1次的多项式\(A(x)\),我们如果知道它的\(n\)个不同的\(x\)对应的值(我们称它为点值),那么我们就能还原它的系数。

  • 如果我们要求 \(A(x)\times B(x)\) 的系数,那么我们就可以先算出 \(A\)\(B\)\(n\)\(x\)处对应的点值,将两者相乘,这即是\(A\times B\)\(x\)处的点值,然后将它还原即可。

\(FFT\)就用到了这一点。

总体思路就是:

  • 我们首先将 \(\omega_n^0,\omega_n^1,\omega_n^2,…,\omega_n^{n-1}\) 代入两多项式,算出各多项式的点值(我们称其为 \(DFT\) 操作),再将两者相乘,将得到的值还原为系数(我们称其为 \(iDFT\) 操作)。

\(DFT\)

\(\begin{aligned}A(\omega_n^k)&=\sum_{i=0}^{n-1}a_0\times\omega_n^{ik}\\&=\sum_{i=0}^{\lfloor\frac n 2\rfloor}a_{2i}\times(\omega_n^2)^{ik}+\omega_n^1\sum_{i=0}^{\lfloor\frac n 2\rfloor}a_{2i+1}\times(\omega_n^2)^{ik}\end{aligned}\)

  • 进行奇偶分类:

\(A_0(\omega_n^{k})=\sum_{i=0}^{\lfloor\frac n 2\rfloor}a_{2i}\times(\omega_n^{k})^i,A_1(\omega_n^{k})=\sum_{i=0}^{\lfloor\frac n 2\rfloor}a_{2i+1}\times(\omega_n^{k})^i\)

  • 继续把式子推下去:

\(k\le\frac n 2\)

\(\begin{aligned}A(\omega_n^k)&=A_0(\omega_n^{k})^2+\omega_n^kA_1(\omega_n^{k})^2\\&=A_0(\omega_{\frac n 2}^{k})+\omega_n^kA_1(\omega_{\frac n 2}^{k})\end{aligned}\)

\(\begin{aligned}A(\omega_n^{k+\frac n 2})&=A_0(\omega_n^{k+\frac n 2})^2+\omega_n^{k}A_1(\omega_n^{k+\frac n 2})^2\\&=A_0(\omega_{\frac n 2}^{k+\frac n 2})+\omega_n^{k}A_1(\omega_{\frac n 2}^{k+\frac n 2})\\&=A_0(\omega_{\frac n 2}^k)-\omega_n^{k}A_1(\omega_{\frac n 2}^k)\end{aligned}\)

有了这个式子我们就能迭代求解了。

时间复杂度是\(\varTheta(n\log n)\) 的。

\(iDFT\)

进入\(iDFT\) 的过程。

\(C=A\times B\)

考虑:

\(\begin{aligned}C(\omega_n^k)&=DFT.A(\omega_n^k)\times DFT.B(\omega_n^k)\\\sum_{i=0}^{n-1}c[i]\times(\omega_n^k)^i&=DFT.A(\omega_n^k)\times DFT.B(\omega_n^k)\\\sum_{i=0}^{n-1}DFT.A(\omega_n^i)\times DFT.B(\omega_n^i)\times(\omega_n^{-k})^i&=\sum_{i=0}^{n-1}\sum_{j=0}^{n-1}c[j]\times(\omega_n^i)^j\times(\omega_n^{-k})^i\\\sum_{i=0}^{n-1}DFT.A(\omega_n^i)\times DFT.B(\omega_n^i)\times(\omega_n^{-k})^i&=\sum_{j=0}^{n-1}c[j]\sum_{i=0}^{n-1}\omega_n^{i(j-k)}\end{aligned}\)

  • \(j=k\) 时:

\(\omega_n^{i(j-k)}=\omega_n^0=1\)

\(\sum_{i=0}^{n-1}\omega_n^{i(j-k)}=n\)

  • \(j\ne k\) 时:

\(\begin{aligned}\sum_{i=0}^{n-1}\omega_n^{i(j-k)}&=\sum_{i=0}^{\lfloor\frac{n-1} 2\rfloor}\omega_n^{i(j-k)}+\omega_n^{(n-i)(j-k)}\\&=\sum_{i=0}^{\lfloor\frac{n-1}2\rfloor}\omega_n^{i(j-k)}+\omega_n^{-i(j-k)}\\&=0\end{aligned}\)

  • 那么我们就有:

\(\begin{aligned}DFT.A(\omega_n^i)\times DFT.B(\omega_n^i)\times(\omega_n^{-k})^i&=c[k]\times n\\c[k]&=\frac{DFT.A(\omega_n^i)\times DFT.B(\omega_n^i)\times(\omega_n^{-k})^i}n\end{aligned}\)

  • 所以我们只需要把\(DFT.A(\omega_n^i)\times DFT.B(\omega_n^i)\) 当作系数,将\(\omega_n^{-k}\)代入\(DFT\)的过程,求出来的点值除以\(n\),就是\(A\times B\)的系数。

代码:

#include<bits/stdc++.h>
#define cp complex<double>
using namespace std;
const int N=2100005; double pi=acos(-1);
cp a[N],b[N],c[N];
int n,m,lg;
void FFT(cp *a,int n,int inv){
	if (n==1) return;
	int m=n/2;
	cp a0[n],a1[n];
	for (int i=0;i<m;i++)
		a0[i]=a[i*2],a1[i]=a[i*2+1];
	FFT(a0,m,inv); FFT(a1,m,inv);
	cp W(cos(pi/m),inv*sin(pi/m)),w(1,0);
	for (int i=0;i<m;i++,w*=W){
		a[i]=a0[i]+w*a1[i];
		a[i+m]=a0[i]-w*a1[i];
	}
}
int main(){
	scanf("%d%d",&n,&m);
	for (int i=0;i<=n;i++) scanf("%lf",&a[i]);
	for (int i=0;i<=m;i++) scanf("%lf",&b[i]);
	int lg=1;
	while (lg<n+m+1) lg<<=1;
	FFT(a,lg,1); FFT(b,lg,1);
	for (int i=0;i<lg;i++) c[i]=a[i]*b[i];
	FFT(c,lg,-1);
	for (int i=0;i<n+m+1;i++) printf("%.0f ",fabs(c[i].real()/lg));
}
  • 我们可以用一个“蝴蝶变换”,把递归变成循环

至于蝴蝶变换是什么嘛。。。

硬证,能证,意义不大,当个规律记住就完事儿了。

  • 代码(蝴蝶变换):
#include<bits/stdc++.h>
#define cp complex<double>
using namespace std;
const int N=2100005; double pi=acos(-1);
cp a[N],b[N],c[N],u,t;
int n,m,x,lg;
void FFT(cp *a,int n,int opt){
	for (int i=0;i<n;i++){
		int t=0;
		for (int j=0;j<lg;j++)
			if ((i>>j)&1) t|=(1<<(lg-j-1));
		if (i<t) swap(a[i],a[t]);
	}
	for (int m=1;m<n;m<<=1){
		cp omega(cos(pi/m),opt*sin(pi/m));
		for (int k=0;k<n;k+=m*2){
			cp now(1,0);
			for (int j=0;j<m;j++){
				u=a[k+j],t=now*a[k+j+m];
				a[k+j]=u+t;
				a[k+j+m]=u-t;
				now*=omega;
			}
		}
	}
}
int main(){
	scanf("%d%d",&n,&m);
	n++,m++;
	for (int i=0;i<n;i++) scanf("%lf",&a[i]);
	for (int i=0;i<m;i++) scanf("%lf",&b[i]);
	x=1;
	while (x<n+m-1) x<<=1,lg++;
	FFT(a,x,1); FFT(b,x,1);
	for (int i=0;i<x;i++) c[i]=a[i]*b[i];
	FFT(c,x,-1);
	for (int i=0;i<n+m-1;i++) printf("%.0f ",fabs(c[i].real()/x));
}

\(FFT的优化\)

  • 注意到我们平时做FFT的时候只用了复数的实部,那么如果我们把虚部也利用一下,是不是会更优呢?

设:\(\begin{aligned}P(x)&=A(x)+i\times B(x)\\Q(x)&=A(x)-i\times B(x)\end{aligned}\)

看起来非常对称,根据经验,如果满足对称性,数学往往会给你一些奖励。

  • 那么现在我们对\(P\)\(Q\)求个\(DFT\)

\(\begin{aligned}
DFT.P(\omega_n^p)&=\sum_{k=0}^{n-1}a[k]\times(\omega_n^p)^k+i\times b[k]\times(\omega_n^p)^k\\
&=\sum_{k=0}^{n-1}(a[k]\times\cos\frac{pk}{2n\pi}-b[k]\times\sin\frac{pk}{2n\pi})+i\times(a[k]\times\sin\frac{pk}{2n\pi}+b[k]\times\cos\frac{pk}{2n\pi})\\
DFT.Q(\omega_n^p)&=\sum_{k=0}^{n-1}(a[k]\times\cos\frac{pk}{2n\pi}+b[k]\times\sin\frac{pk}{2n\pi})+i\times(a[k]\times\sin\frac{pk}{2n\pi}-b[k]\times\cos\frac{pk}{2n\pi})
\end{aligned}\)

  • \(DFT.P\)稍微做一个转化:

\(\begin{aligned}
DFT.P(\omega_n^p)&=\sum_{k=0}^{n-1}(a[k]\times\cos\frac{pk}{2n\pi}-b[k]\times\sin\frac{pk}{2n\pi})+i\times(a[k]\times\sin\frac{pk}{2n\pi}+b[k]\times\cos\frac{pk}{2n\pi})\\
&=\sum_{k=0}^{n-1}(a[k]\times\cos\frac{(-p)k}{2n\pi}+b[k]\times\sin\frac{(-p)k}{2n\pi})-i\times(a[k]\times\sin\frac{(-p)k}{2n\pi}-b[k]\times\cos\frac{(-p)k}{2n\pi})\\
&=conj(DFT.Q(\omega_n^{-p}))\\
&=conj(DFT.Q(\omega_n^{n-p}))
\end{aligned}\)

其中:\(conj(x)\)是共轭。(\(conj(a+ib)=a-ib\)

  • 我们发现:只要对\(P\)做一个\(DFT\),同时也能求出\(DFT.Q\)

  • 接着还原\(A,B\)的点值:

\(\begin{aligned}DFT.A(x)=\frac{DFT.P(x)+DFT.Q(x)}2\end{aligned}\)

\(\begin{aligned}DFT.B(x)=\frac{DFT.P(x)-DFT.Q(x)}{2i}\end{aligned}\)


现在,我们优化了\(DFT\)的过程,我们能否也优化一下\(iDFT\)的过程呢

设:\(P(x)=DFT.A(x)+i\times DFT.B(x)\)

(这里的\(DFT\)是乘完后的\(DFT\),即\(DFT.A(x)=DFT.A_1(x)\times DFT.A_2(x)\),而我们要求\(A_1\times A_2\),而\(B\)则是另一组多项式在做乘法)

  • \(\begin{aligned}
    DFT.P(\omega_n^{-p})&=\sum_{k=0}^{n-1}(\omega_n^{-p})^k\times(DFT.A(\omega_n^p)+i\times DFT.B(\omega_n^p))\\
    &=\sum_{k=0}^{n-1}(\omega_n^{-p})^k\times(\sum_{j=0}^{n-1}a[j]\times \omega_n^{kj}+i\times b[j]\times\omega_n^{kj})\\
    &=\sum_{j=0}^{n-1}(a[j]+i\times b[j])\sum_{k=0}^{n-1}\omega_n^{k(j-p)}\\
    &=\frac{a[p]+i\times b[p]}n
    \end{aligned}\)

(最后一步转换就是和之前类似的分类讨论)

\(a,b\) 为乘完后的多项式系数)

可以发现,我们\(A,B\)的系数仍然分别保留在复数的实部和虚部,接着只要将它还原即可。


有了这个优化,对于多项式乘法,比如\(A\times B\),我们总共只需要进行两次\(DFT\)即可。

当然,只处理一个多项式乘法,并不需要用到我们对\(iDFT\)的优化,我们对于\(iDFT\)的优化,是在同时处理两组及以上的多项式相乘时才有效果的。

拆系数\(FFT\)

我们在处理要求取模的问题时,可以用到拆系数\(FFT\)

假设我们要求\(A\times B\)的系数(\(\mod p\)

  • 拆掉原系数

\(\begin{aligned}&a[i]=A[i]>>15\\&b[i]=A[i]\&(2^{15}-1)\\&c[i]=B[i]>>15\\&d[i]=B[i]\&(2^{15}-1)\end{aligned}\)

  • 分别求出他们的点值(要取模)

  • 将次数相同的项组合:

\(Ans_1=D.a\times D.c\)\(Ans_2=D.b\times D.c+D.a\times D.d\)\(Ans_3=D.b\times D.d\)

\(DFT.x\)被简写为\(D.x\)

  • \(Ans_1,Ans_2,Ans_3\)进行\(iDFT\)

  • 还原出系数

\(\begin{aligned}Ans[x]=iDFT.Ans_1[x]\times 2^{30}+iDFT.Ans_2\times2^{15}+iDFT.Ans_3\end{aligned}\)

(要取模)

这里没啥好讲的,就来简单讲一下为什么拆完之后分别做\(DFT\)\(iDFT\),再分别取模还是对的吧(挺显然的,其实)

我们把原系数拆分之后再做\(DFT\),在\(DFT\),重新组和,\(iDFT\)的过程中,是不会爆\(long\) \(long\) 的,也就是说我们求出的就是未经过取模的真实点值。

有了点值,再进行\(iDFT\),我们还原出的系数也就是将原系数进行拆分后的系数。

然后按照原来的比例相乘组合即可,在此时取模是没有影响的,得到也即是真正的系数取模后的值。

  • 代码:
#include<bits/stdc++.h>
#define ll long long
using namespace std;
const int N=1<<19|5; long double pi=acos(-1);
struct cp{
    long double a,b;
    cp() {a=b=0;}
    cp(long double _a, long double _b) {a=_a,b=_b;}
    cp operator +(cp x) const {return cp(a+x.a,b+x.b);}
    cp operator -(cp x) const {return cp(a-x.a,b-x.b);}
    cp operator *(cp x) const {return cp(a*x.a-b*x.b,a*x.b+b*x.a);}
    cp Conj() {return cp(a,-b);}
};
cp p[N],q[N],a[N],b[N],c[N],d[N],kx,bx,ky,by,u,t;
int ret[N],L1,L2;
ll x[N],y[N],ans[N],mod,len,lg;
void FFT(cp *a,int n,long double opt){
    int m=1;
    for (int i=0;i<n;i++)
        if (i<ret[i]) swap(a[i],a[ret[i]]);
    for (int s=1;s<=lg;s++){
        cp omega(cos(pi/m),opt*sin(pi/m));
        m<<=1;
        for (int k=0;k<n;k+=m){
            cp now(1,0);
            for (int j=0;j<m/2;j++){
                u=a[k+j]; t=a[k+j+m/2]*now;
                a[k+j]=u+t;
                a[k+j+m/2]=u-t;
                now=now*omega;
            }
        }
    }
}
void mul(ll *x,ll *y,ll *ans,int n){
    cp r(0.5,0),h(0,-0.5),o(0,1);
    for (int i=0;i<n;i++) 
        ret[i]=(ret[i>>1]>>1)|((i&1)<<(lg-1));
    for (int i=0;i<=L1;i++)
        p[i]=cp(x[i]>>15,x[i]&32767);
    for (int i=0;i<=L2;i++)
        q[i]=cp(y[i]>>15,y[i]&32767);
    FFT(p,n,1); FFT(q,n,1);
    for (int i=0;i<n;i++){
        int j=(n-i)&(n-1);
        kx=cp(p[i]+p[j].Conj())*r;
        bx=cp(p[i]-p[j].Conj())*h;
        ky=cp(q[i]+q[j].Conj())*r;
        by=cp(q[i]-q[j].Conj())*h;
        a[i]=kx*ky; b[i]=kx*by; c[i]=bx*ky; d[i]=bx*by;
    }
    for (int i=0;i<n;i++){
        p[i]=a[i]+(b[i]+c[i])*o;
        q[i]=d[i];
    }
    FFT(p,n,-1); FFT(q,n,-1);
    for (int i=0;i<n;i++){
        ll a,b,c;
        a=(ll)(p[i].a/n+0.5)%mod;
        b=(ll)(p[i].b/n+0.5)%mod;
        c=(ll)(q[i].a/n+0.5)%mod;
        ans[i]=(((a<<30)%mod+(b<<15)%mod)+c)%mod;
        if (ans[i]<0) ans+=mod;
    }
}
int main(){
    scanf("%d%d%lld",&L1,&L2,&mod);
    for (int i=0;i<=L1;i++) scanf("%lld",&x[i]),x[i]%=mod;
    for (int i=0;i<=L2;i++) scanf("%lld",&y[i]),y[i]%=mod;
    for (len=1;len<=L1+L2+1;len<<=1) lg++;
    mul(x,y,ans,len);
    for (int i=0;i<L1+L2+1;i++) printf("%lld ",ans[i]);
}