[學習筆記] FFT
填一下\(FFT\) 的坑。
我們的目標是求\(A\times B\)的係數(\(A,B\)為係數已知的多項式)
關於 \(\omega\)
大概講一下
我們有複數 \(\omega\)。
\(\omega_n^k=\cos \frac{2k\pi}{n}+i\times\sin\frac{2k\pi}{n}\)
可以把它看作把單位圓進行n等分,然後進行標號,\(k\) 每增加1,即使逆時針往後移一個
這東西有幾個性質:
- \(\omega_n^{k+n}=\omega_n^k\)
相當於轉了一圈又回來了
- \(\omega_n^{k+\frac n 2}=-\omega_n^k\)
轉到對面去了
- \(\omega_n^k=\omega_{\frac n 2}^{\frac k 2}\)
進行了縮放,但還是在同一個位置
大體思路
對於一個n-1次的多項式\(A(x)\),我們如果知道它的\(n\)個不同的\(x\)對應的值(我們稱它為點值),那麼我們就能還原它的係數。
- 如果我們要求 \(A(x)\times B(x)\) 的係數,那麼我們就可以先算出 \(A\) 和 \(B\) 在\(n\)個\(x\)處對應的點值,將兩者相乘,這即是\(A\times B\)在\(x\)處的點值,然後將它還原即可。
而\(FFT\)就用到了這一點。
總體思路就是:
- 我們首先將 \(\omega_n^0,\omega_n^1,\omega_n^2,…,\omega_n^{n-1}\) 代入兩多項式,算出各多項式的點值(我們稱其為 \(DFT\) 操作),再將兩者相乘,將得到的值還原為係數(我們稱其為 \(iDFT\) 操作)。
\(DFT\)
\(\begin{aligned}A(\omega_n^k)&=\sum_{i=0}^{n-1}a_0\times\omega_n^{ik}\\&=\sum_{i=0}^{\lfloor\frac n 2\rfloor}a_{2i}\times(\omega_n^2)^{ik}+\omega_n^1\sum_{i=0}^{\lfloor\frac n 2\rfloor}a_{2i+1}\times(\omega_n^2)^{ik}\end{aligned}\)
- 進行奇偶分類:
\(A_0(\omega_n^{k})=\sum_{i=0}^{\lfloor\frac n 2\rfloor}a_{2i}\times(\omega_n^{k})^i,A_1(\omega_n^{k})=\sum_{i=0}^{\lfloor\frac n 2\rfloor}a_{2i+1}\times(\omega_n^{k})^i\)
- 繼續把式子推下去:
(\(k\le\frac n 2\))
\(\begin{aligned}A(\omega_n^k)&=A_0(\omega_n^{k})^2+\omega_n^kA_1(\omega_n^{k})^2\\&=A_0(\omega_{\frac n 2}^{k})+\omega_n^kA_1(\omega_{\frac n 2}^{k})\end{aligned}\)
\(\begin{aligned}A(\omega_n^{k+\frac n 2})&=A_0(\omega_n^{k+\frac n 2})^2+\omega_n^{k}A_1(\omega_n^{k+\frac n 2})^2\\&=A_0(\omega_{\frac n 2}^{k+\frac n 2})+\omega_n^{k}A_1(\omega_{\frac n 2}^{k+\frac n 2})\\&=A_0(\omega_{\frac n 2}^k)-\omega_n^{k}A_1(\omega_{\frac n 2}^k)\end{aligned}\)
有了這個式子我們就能迭代求解了。
時間複雜度是\(\varTheta(n\log n)\) 的。
\(iDFT\)
進入\(iDFT\) 的過程。
設 \(C=A\times B\)
考慮:
\(\begin{aligned}C(\omega_n^k)&=DFT.A(\omega_n^k)\times DFT.B(\omega_n^k)\\\sum_{i=0}^{n-1}c[i]\times(\omega_n^k)^i&=DFT.A(\omega_n^k)\times DFT.B(\omega_n^k)\\\sum_{i=0}^{n-1}DFT.A(\omega_n^i)\times DFT.B(\omega_n^i)\times(\omega_n^{-k})^i&=\sum_{i=0}^{n-1}\sum_{j=0}^{n-1}c[j]\times(\omega_n^i)^j\times(\omega_n^{-k})^i\\\sum_{i=0}^{n-1}DFT.A(\omega_n^i)\times DFT.B(\omega_n^i)\times(\omega_n^{-k})^i&=\sum_{j=0}^{n-1}c[j]\sum_{i=0}^{n-1}\omega_n^{i(j-k)}\end{aligned}\)
- 當 \(j=k\) 時:
\(\omega_n^{i(j-k)}=\omega_n^0=1\)
\(\sum_{i=0}^{n-1}\omega_n^{i(j-k)}=n\)
- 當 \(j\ne k\) 時:
\(\begin{aligned}\sum_{i=0}^{n-1}\omega_n^{i(j-k)}&=\sum_{i=0}^{\lfloor\frac{n-1} 2\rfloor}\omega_n^{i(j-k)}+\omega_n^{(n-i)(j-k)}\\&=\sum_{i=0}^{\lfloor\frac{n-1}2\rfloor}\omega_n^{i(j-k)}+\omega_n^{-i(j-k)}\\&=0\end{aligned}\)
- 那麼我們就有:
\(\begin{aligned}DFT.A(\omega_n^i)\times DFT.B(\omega_n^i)\times(\omega_n^{-k})^i&=c[k]\times n\\c[k]&=\frac{DFT.A(\omega_n^i)\times DFT.B(\omega_n^i)\times(\omega_n^{-k})^i}n\end{aligned}\)
- 所以我們只需要把\(DFT.A(\omega_n^i)\times DFT.B(\omega_n^i)\) 當作係數,將\(\omega_n^{-k}\)代入\(DFT\)的過程,求出來的點值除以\(n\),就是\(A\times B\)的係數。
程式碼:
#include<bits/stdc++.h>
#define cp complex<double>
using namespace std;
const int N=2100005; double pi=acos(-1);
cp a[N],b[N],c[N];
int n,m,lg;
void FFT(cp *a,int n,int inv){
if (n==1) return;
int m=n/2;
cp a0[n],a1[n];
for (int i=0;i<m;i++)
a0[i]=a[i*2],a1[i]=a[i*2+1];
FFT(a0,m,inv); FFT(a1,m,inv);
cp W(cos(pi/m),inv*sin(pi/m)),w(1,0);
for (int i=0;i<m;i++,w*=W){
a[i]=a0[i]+w*a1[i];
a[i+m]=a0[i]-w*a1[i];
}
}
int main(){
scanf("%d%d",&n,&m);
for (int i=0;i<=n;i++) scanf("%lf",&a[i]);
for (int i=0;i<=m;i++) scanf("%lf",&b[i]);
int lg=1;
while (lg<n+m+1) lg<<=1;
FFT(a,lg,1); FFT(b,lg,1);
for (int i=0;i<lg;i++) c[i]=a[i]*b[i];
FFT(c,lg,-1);
for (int i=0;i<n+m+1;i++) printf("%.0f ",fabs(c[i].real()/lg));
}
- 我們可以用一個「蝴蝶變換」,把遞歸變成循環
至於蝴蝶變換是什麼嘛。。。
硬證,能證,意義不大,當個規律記住就完事兒了。
- 程式碼(蝴蝶變換):
#include<bits/stdc++.h>
#define cp complex<double>
using namespace std;
const int N=2100005; double pi=acos(-1);
cp a[N],b[N],c[N],u,t;
int n,m,x,lg;
void FFT(cp *a,int n,int opt){
for (int i=0;i<n;i++){
int t=0;
for (int j=0;j<lg;j++)
if ((i>>j)&1) t|=(1<<(lg-j-1));
if (i<t) swap(a[i],a[t]);
}
for (int m=1;m<n;m<<=1){
cp omega(cos(pi/m),opt*sin(pi/m));
for (int k=0;k<n;k+=m*2){
cp now(1,0);
for (int j=0;j<m;j++){
u=a[k+j],t=now*a[k+j+m];
a[k+j]=u+t;
a[k+j+m]=u-t;
now*=omega;
}
}
}
}
int main(){
scanf("%d%d",&n,&m);
n++,m++;
for (int i=0;i<n;i++) scanf("%lf",&a[i]);
for (int i=0;i<m;i++) scanf("%lf",&b[i]);
x=1;
while (x<n+m-1) x<<=1,lg++;
FFT(a,x,1); FFT(b,x,1);
for (int i=0;i<x;i++) c[i]=a[i]*b[i];
FFT(c,x,-1);
for (int i=0;i<n+m-1;i++) printf("%.0f ",fabs(c[i].real()/x));
}
\(FFT的優化\)
- 注意到我們平時做FFT的時候只用了複數的實部,那麼如果我們把虛部也利用一下,是不是會更優呢?
設:\(\begin{aligned}P(x)&=A(x)+i\times B(x)\\Q(x)&=A(x)-i\times B(x)\end{aligned}\)
看起來非常對稱,根據經驗,如果滿足對稱性,數學往往會給你一些獎勵。
- 那麼現在我們對\(P\),\(Q\)求個\(DFT\)
\(\begin{aligned}
DFT.P(\omega_n^p)&=\sum_{k=0}^{n-1}a[k]\times(\omega_n^p)^k+i\times b[k]\times(\omega_n^p)^k\\
&=\sum_{k=0}^{n-1}(a[k]\times\cos\frac{pk}{2n\pi}-b[k]\times\sin\frac{pk}{2n\pi})+i\times(a[k]\times\sin\frac{pk}{2n\pi}+b[k]\times\cos\frac{pk}{2n\pi})\\
DFT.Q(\omega_n^p)&=\sum_{k=0}^{n-1}(a[k]\times\cos\frac{pk}{2n\pi}+b[k]\times\sin\frac{pk}{2n\pi})+i\times(a[k]\times\sin\frac{pk}{2n\pi}-b[k]\times\cos\frac{pk}{2n\pi})
\end{aligned}\)
- 對\(DFT.P\)稍微做一個轉化:
\(\begin{aligned}
DFT.P(\omega_n^p)&=\sum_{k=0}^{n-1}(a[k]\times\cos\frac{pk}{2n\pi}-b[k]\times\sin\frac{pk}{2n\pi})+i\times(a[k]\times\sin\frac{pk}{2n\pi}+b[k]\times\cos\frac{pk}{2n\pi})\\
&=\sum_{k=0}^{n-1}(a[k]\times\cos\frac{(-p)k}{2n\pi}+b[k]\times\sin\frac{(-p)k}{2n\pi})-i\times(a[k]\times\sin\frac{(-p)k}{2n\pi}-b[k]\times\cos\frac{(-p)k}{2n\pi})\\
&=conj(DFT.Q(\omega_n^{-p}))\\
&=conj(DFT.Q(\omega_n^{n-p}))
\end{aligned}\)
其中:\(conj(x)\)是共軛。(\(conj(a+ib)=a-ib\))
-
我們發現:只要對\(P\)做一個\(DFT\),同時也能求出\(DFT.Q\)
-
接著還原\(A,B\)的點值:
\(\begin{aligned}DFT.A(x)=\frac{DFT.P(x)+DFT.Q(x)}2\end{aligned}\)
\(\begin{aligned}DFT.B(x)=\frac{DFT.P(x)-DFT.Q(x)}{2i}\end{aligned}\)
現在,我們優化了\(DFT\)的過程,我們能否也優化一下\(iDFT\)的過程呢
設:\(P(x)=DFT.A(x)+i\times DFT.B(x)\)
(這裡的\(DFT\)是乘完後的\(DFT\),即\(DFT.A(x)=DFT.A_1(x)\times DFT.A_2(x)\),而我們要求\(A_1\times A_2\),而\(B\)則是另一組多項式在做乘法)
- \(\begin{aligned}
DFT.P(\omega_n^{-p})&=\sum_{k=0}^{n-1}(\omega_n^{-p})^k\times(DFT.A(\omega_n^p)+i\times DFT.B(\omega_n^p))\\
&=\sum_{k=0}^{n-1}(\omega_n^{-p})^k\times(\sum_{j=0}^{n-1}a[j]\times \omega_n^{kj}+i\times b[j]\times\omega_n^{kj})\\
&=\sum_{j=0}^{n-1}(a[j]+i\times b[j])\sum_{k=0}^{n-1}\omega_n^{k(j-p)}\\
&=\frac{a[p]+i\times b[p]}n
\end{aligned}\)
(最後一步轉換就是和之前類似的分類討論)
(\(a,b\) 為乘完後的多項式係數)
可以發現,我們\(A,B\)的係數仍然分別保留在複數的實部和虛部,接著只要將它還原即可。
有了這個優化,對於多項式乘法,比如\(A\times B\),我們總共只需要進行兩次\(DFT\)即可。
當然,只處理一個多項式乘法,並不需要用到我們對\(iDFT\)的優化,我們對於\(iDFT\)的優化,是在同時處理兩組及以上的多項式相乘時才有效果的。
拆係數\(FFT\)
我們在處理要求取模的問題時,可以用到拆係數\(FFT\)。
假設我們要求\(A\times B\)的係數(\(\mod p\))
- 拆掉原係數
\(\begin{aligned}&a[i]=A[i]>>15\\&b[i]=A[i]\&(2^{15}-1)\\&c[i]=B[i]>>15\\&d[i]=B[i]\&(2^{15}-1)\end{aligned}\)
-
分別求出他們的點值(要取模)
-
將次數相同的項組合:
\(Ans_1=D.a\times D.c\),\(Ans_2=D.b\times D.c+D.a\times D.d\),\(Ans_3=D.b\times D.d\)
(\(DFT.x\)被簡寫為\(D.x\))
-
對\(Ans_1,Ans_2,Ans_3\)進行\(iDFT\)
-
還原出係數
\(\begin{aligned}Ans[x]=iDFT.Ans_1[x]\times 2^{30}+iDFT.Ans_2\times2^{15}+iDFT.Ans_3\end{aligned}\)
(要取模)
這裡沒啥好講的,就來簡單講一下為什麼拆完之後分別做\(DFT\)、\(iDFT\),再分別取模還是對的吧(挺顯然的,其實)
我們把原係數拆分之後再做\(DFT\),在\(DFT\),重新組和,\(iDFT\)的過程中,是不會爆\(long\) \(long\) 的,也就是說我們求出的就是未經過取模的真實點值。
有了點值,再進行\(iDFT\),我們還原出的係數也就是將原係數進行拆分後的係數。
然後按照原來的比例相乘組合即可,在此時取模是沒有影響的,得到也即是真正的係數取模後的值。
- 程式碼:
#include<bits/stdc++.h>
#define ll long long
using namespace std;
const int N=1<<19|5; long double pi=acos(-1);
struct cp{
long double a,b;
cp() {a=b=0;}
cp(long double _a, long double _b) {a=_a,b=_b;}
cp operator +(cp x) const {return cp(a+x.a,b+x.b);}
cp operator -(cp x) const {return cp(a-x.a,b-x.b);}
cp operator *(cp x) const {return cp(a*x.a-b*x.b,a*x.b+b*x.a);}
cp Conj() {return cp(a,-b);}
};
cp p[N],q[N],a[N],b[N],c[N],d[N],kx,bx,ky,by,u,t;
int ret[N],L1,L2;
ll x[N],y[N],ans[N],mod,len,lg;
void FFT(cp *a,int n,long double opt){
int m=1;
for (int i=0;i<n;i++)
if (i<ret[i]) swap(a[i],a[ret[i]]);
for (int s=1;s<=lg;s++){
cp omega(cos(pi/m),opt*sin(pi/m));
m<<=1;
for (int k=0;k<n;k+=m){
cp now(1,0);
for (int j=0;j<m/2;j++){
u=a[k+j]; t=a[k+j+m/2]*now;
a[k+j]=u+t;
a[k+j+m/2]=u-t;
now=now*omega;
}
}
}
}
void mul(ll *x,ll *y,ll *ans,int n){
cp r(0.5,0),h(0,-0.5),o(0,1);
for (int i=0;i<n;i++)
ret[i]=(ret[i>>1]>>1)|((i&1)<<(lg-1));
for (int i=0;i<=L1;i++)
p[i]=cp(x[i]>>15,x[i]&32767);
for (int i=0;i<=L2;i++)
q[i]=cp(y[i]>>15,y[i]&32767);
FFT(p,n,1); FFT(q,n,1);
for (int i=0;i<n;i++){
int j=(n-i)&(n-1);
kx=cp(p[i]+p[j].Conj())*r;
bx=cp(p[i]-p[j].Conj())*h;
ky=cp(q[i]+q[j].Conj())*r;
by=cp(q[i]-q[j].Conj())*h;
a[i]=kx*ky; b[i]=kx*by; c[i]=bx*ky; d[i]=bx*by;
}
for (int i=0;i<n;i++){
p[i]=a[i]+(b[i]+c[i])*o;
q[i]=d[i];
}
FFT(p,n,-1); FFT(q,n,-1);
for (int i=0;i<n;i++){
ll a,b,c;
a=(ll)(p[i].a/n+0.5)%mod;
b=(ll)(p[i].b/n+0.5)%mod;
c=(ll)(q[i].a/n+0.5)%mod;
ans[i]=(((a<<30)%mod+(b<<15)%mod)+c)%mod;
if (ans[i]<0) ans+=mod;
}
}
int main(){
scanf("%d%d%lld",&L1,&L2,&mod);
for (int i=0;i<=L1;i++) scanf("%lld",&x[i]),x[i]%=mod;
for (int i=0;i<=L2;i++) scanf("%lld",&y[i]),y[i]%=mod;
for (len=1;len<=L1+L2+1;len<<=1) lg++;
mul(x,y,ans,len);
for (int i=0;i<L1+L2+1;i++) printf("%lld ",ans[i]);
}