FFT 學習筆記(自認為詳細)

引入

什麼是 \(\text{FFT}\)
反正我看到 \(\text{wiki}\) 上是一堆奇怪的東西。

快速傅里葉變換(英語:Fast Fourier Transform, FFT),是快速計算序列的離散傅里葉變換(DFT)或其逆變換的方法。傅里葉分析將訊號從原始域(通常是時間或空間)轉換到頻域的表示或者逆過來轉換。FFT會通過把DFT矩陣分解為稀疏(大多為零)因子之積來快速計算此類變換。—— \(\text{wikipedia}\)

反正我沒腦子我看不懂。
對我來說,\(\text{FFT}\) 就是能把多項式乘法從 \(O(n^2)\) 變成 \(O(n\log n)\) 的神仙玩意。

正文

係數表示法和點值表示法

對於係數表示法,就是用多項式的係數來表示這個多項式。
比如說:

\[f(x)=a_1x^3+a_2x^2+a_3x+a_4\Leftrightarrow f(x)=\{a_1,a_2,a_3,a_4\}
\]

那麼對於點值表示法,相對應的就是用該函數上的若干個點表示多項式。
學過小學數學的同學們一定知道:\(n+1\) 個點確定一個 \(n\) 次多項式。
證明的話可以考慮數學歸納法。/xyx
同樣舉一個例子,點值表示法是這樣的:

\[f(x)=a_0+a_1x+a_2x^2+\cdots+a_nx^n⇔f(x)=\{(x_0,y_0),(x_1,y_1),(x_2,y_2),\cdots,(x_n,y_n)\}\\
\]

上面講到要把係數表示法轉換成點值表示法。那麼這是為什麼呢?
下面就先來展示一下點值表示法的多項式乘法:

\[f(x)=\{(x_0,f(x_0)),(x_1,f(x_1)),(x_2,f(x_2)),\cdots,(x_n,f(x_n))\}\\
g(x)=\{(x_0,g(x_0)),(x_1,g(x_1)),(x_2,g(x_2)),\cdots,(x_n,g(x_n))\}\\
F(x) = f(x)\times g(x)\\
F(x)=\{(x_0,g(x_0)\times f(x_0)),(x_1,g(x_1)\times f(x_1)),\cdots,(x_n,g(x_n)\times f(x_n))\}
\]

複數

\(複數 = 實數 + 虛數\)
實在一點吧,直接上乾貨,我們定義 :

\[i=\sqrt{-1}
\]

這樣我們就可以表示我們在實數範圍內不能表示的數了。
那麼如何表示一個複數呢:

\[Num=a+bi\ \ (a,b \in \mathbb{R})
\]

接著我們把 \(Num=a+bi\) 看成一個函數,把 \(a\)\(b\) 分別對應 \(x\) 軸和 \(y\) 軸。
就可以得到複數平面,大概長這樣:
image
其中橫坐標是實數軸,縱坐標是虛數軸,這樣就可以把每個虛數看為一個向量了。
對應的,虛數可以用普通坐標和極坐標表示:

\[(x,y)\quad和\quad (r,\theta)
\]

下面給出兩個複數相乘的意義:

\[\begin{split}&\quad
(a+bi)\times(c+di)\\&=
ac+adi+bci+bdi^2\\
&=(ac+bd)+(ad+bc)i
\end{split} \\\quad \\
(r_1,\theta_1)\times(r_2,\theta_2)=(r_1\times r_2,\theta_1+\theta_2)
\]

\(\tt DFT\) (離散傅里葉變換)

現在已經介紹完了點值表示法複數的相關知識,接下來就是乾貨部分了。

上面我們已經通過這樣的例子說明了點值表示法算多項式乘法的方便。
接下來我們來看怎麼先把多項式從係數表示法轉換為點值表示法,這種過程叫 \(\text{DFT}\)

所謂的點值表示法,也就是在 \(n\) 多項式上取 \(n+1\) 個點,來進行表示。
形式化的,可以表示成這樣:

\[F(x)=a_0+a_1x+a_2x^2+\cdots+a_{n-1}x^{n-1}+a_nx^n\\
\rightarrow F(x)=\{(x_0,F(x_0)),(x_1,F(x_1)),(x_2,F(x_2)),\cdots,(x_n,F(x_n))\}
\]

然後可以驚喜的發現,隨便帶幾個 \(x_i\) 進去在算算 \(F(x_i)\) 就好了。
但是如果你小學畢業了,你就可以發現這樣的話不如直接 \(O(n^2)\) 暴力。

所以該怎麼辦?
我們猜想是否存在一些 \(x\) 使得 \(x^n\ (n\in \tt Z^+)\) 的結果都是 \(1\)
這看上去是一個非常好的思路,但是這樣的數有多少個呢?
我能脫口說出兩個 \(1\)\(-1\) ,想一想可以發現其實 \(i\)\(-i\) 也都可以。

但是經過認真思考(看題解)可以發現下圖的單位圓上所有的點都滿足條件。

為了方便,我們在取這 \(n\) 個點時會把這個單位圓平分。

我們從 \((1, 0)\) 這個點開始,按照逆時針的方向從 \(0\) 開始進行編號,形如 \(\omega_n^k\)
其中 \(n\) 表示一共選擇了 \(n\) 個點,\(k\) 表示當前點的編號。

由我們之前介紹的複數乘法的 模長相乘,度數相加

\[(r_1,\theta_1)\times(r_2,\theta_2)=(r_1\times r_2,\theta_1+\theta_2)
\]

並且結合單位圓的性質(所有的點到原點的距離為 \(1\))。
可以得到由 \(\omega_n^1\) 轉換到 \(\omega_n^k\) 的公式:

\[(\omega_n^1)^k=\omega_n^k
\]

我們稱 \(\omega_n^1\)\(n\) 次單位根。
所以可以發現,我們直接帶入 \(\omega_n^i\) 就可以了。

單位根的一些有用的性質

在了解一切的性質之前,我們要先知道單位根 \(\omega_n^i\) 如何表示:

\[\omega_n^k=\cos\frac{k}{n}2\pi+i\times\sin\frac{k}{n}2\pi
\]

這東西的證明你直接照著單位圓上畫一個點然後三角函數入門知識即可。

性質一

\[\omega_n^k=\omega_{2n}^{2k}
\]

證明的話直接照著上面給出的式子套即可,然後發現可以約分。
那我認為進一步的可以得到:

\[\omega_n^k=\omega_{Pn}^{Pk} \quad (p\in \tt Z^+)
\]

很顯然不過好像沒有什麼大用。

性質二

\[\omega_n^{k+\frac{n}{2}}=-\omega_n^k
\]

證明的話稍微寫一下吧:

\[\omega_n^k=\cos\frac{k}{n}2\pi+i\times\sin\frac{k}{n}2\pi\\
\begin{split}
\omega_n^{k+\frac{n}{2}}&=
\cos\frac{k+\frac{n}{2}}{n}2\pi+i\times\sin\frac{k+\frac{n}{2}}{n}2\pi\\
&=\cos(\frac{k}{n}2\pi+\pi)+i\times\sin(\frac{k}{n}2\pi+\pi)
\end{split}
\]

都化成這一步了就不在進行下一步證明,還看不懂的建議重修初中數學。

性質三

\[\omega_n^0=\omega_n^n
\]

比較憨,我就不講為什麼了。

\(\tt FFT\) (快速傅里葉變換)

他來了,他來了,等到現在他終於來了。。。。

之前講到我們直接帶入 \(\omega_n^i\) 來計算點值。
是的,我認為這種方法高效,巧妙,逼格高,體現了人類智慧。
但是等等,雖然算係數的過程免掉了,但是對於每一個 \(\omega_n^i\) 我們還是要 \(O(n)\) 算結果啊。
然後我搬來搬手指算了一下,發現一共有 \(n\)\(\omega_n^i\) 的值,然後就又 \(O(n^2)\) 了。
所以我們該怎麼辦?

認真地看看題解,發現可以從分治的角度入手。
注意:以下的內容保證 \(n\)\(2\) 的整數次方。
我們設一個多項式:

\[\begin{split}
F(x)&=\sum_{i=0}^{n-1}a_ix^i\\
&=a_0 + a_1x + a_2x^2 + \cdots + a_{n-1}x^{n-1}
\end{split}
\]

然後想辦法把 \(F(x)\) 分成兩個部分。
這裡採用的方法是按照 \(F(x)\) 下標的奇偶性分成兩個部分。

\[\begin{split}
F(x)
&=a_0 + a_1x + a_2x^2 + \cdots + a_{n-1}x^{n-1}\\
&=(a_0+a_2x^2+\cdots+a_{n-2}x^{n-2})+x(a_1+a_3x^2+\cdots+a_{n-1}x^{n-2})
\end{split}
\]

接下來我們發現拆出來的這兩個多項式的結構是一模一樣的。
我們再分別設這兩個多項式為 \(F_1(x)\)\(F_2(x)\)

\[F_1(x)=a_0+a_2x^2+\cdots+a_{n-2}x^{n-2}\\
F_2(x)=a_1+a_3x^2+\cdots+a_{n-1}x^{n-2}\\
F(x)=F_1(x)+xF_2(x)
\]

發現這樣的係數不連續,沒有那麼完美,於是我們再變化一下。

\[F_1(x)=a_0+a_2x^1+\cdots+a_{n-2}x^{\frac{n}{2}-1}\\
F_2(x)=a_1+a_3x^1+\cdots+a_{n-1}x^{\frac{n}{2}-1}\\
F(x) = F_1(x^2)+xF_2(x^2)
\]

此時看可以發現這樣的形式非常的優美。
接下來就是直接帶入 \(\omega_n^i\) 的操作了。
我們接著設 \(k<\frac{n}{2}\) 然後把 \(\omega_n^k\) 直接帶入。

\[\begin{split}
F(\omega_n^k) &= F_1((\omega_n^k)^2)+\omega_n^kF_2((\omega_n^k)^2)\\
&=F_1(\omega_{n}^{2k})+\omega_n^kF_2(\omega_{n}^{2k})\\
&=F_1(\omega_{\frac{n}{2}}^{k})+\omega_n^kF_2(\omega_{\frac{n}{2}}^{k})
\end{split}
\]

第一步直接帶入,有問題的話小學建議重修。
第二步的話我之前寫過,公式是這樣的:

\[(\omega_n^1)^k=\omega_n^k
\]

當然,在這裡運用是具有普遍性的,有問題的話直接推一下。
至於第三步,直接算比例我認為會更加快速一些。
對於 \(F(\omega_n^{k+\frac{n}{2}})\) 直接帶入:

\[\begin{split}
F(\omega_n^{k+\frac{n}{2}}) &= F_1((\omega_n^{k+\frac{n}{2}})^2)+\omega_n^{k+\frac{n}{2}}F_2((\omega_n^{k+\frac{n}{2}})^2)\\
&=F_1(\omega_n^{2k+n})+\omega_n^{k+\frac{n}{2}}F_2(\omega_n^{2k+n})\\
&=F_1(\omega_n^{2k}\omega_n^n)-\omega_n^{k}F_2(\omega_n^{2k}\omega_n^n)\\
&=F_1(\omega_n^{2k})-\omega_n^{k}F_2(\omega_n^{2k})\\
&=F_1(\omega_{\frac{n}{2}}^{k})-\omega_{n}^{k}F_2(\omega_{\frac{n}{2}}^{k})\\
\end{split}
\]

每一步一一介紹比較麻煩,大家直接手頭一下或者翻翻前面的公式。
觀察第一個式子和第二個式子,發現唯一不一樣的地方就是符號了。
然後直接分治求解即可,時間複雜度 \(O(n\log n)\)

\(\tt IFF\) (快速傅里葉逆變換)

就是把點值表示法轉換成為我們要的係數表示法。
這裡給出結論,證明的話屬實比較噁心,所以我就不證明了。

一個多項式在分治的過程中乘上單位根的共軛複數,分治完的每一項除以 \(n\) 即為原多項式的每一項係數

也就是再做一遍 \(\tt FFT\) 輸出時每一位除以 \(n\) 就可以了。

程式碼實現及其優化

Code 複數類型封裝

struct cp {
  double x, y;
  cp (double xx = 0, double yy = 0) {x = xx; y = yy;};

  friend cp operator +(cp p, cp q) {return cp(p.x + q.x, p.y + q.y);}
  friend cp operator -(cp p, cp q) {return cp(p.x - q.x, p.y - q.y);}
  friend cp operator *(cp p, cp q) {return cp(p.x * q.x - p.y * q.y, p.y * q.x + p.x * q.y);}
}a[N], b[N];

Code 無優化

不是我寫的程式碼,反正就是照著之前的公式模擬,看看就好了。

點擊查看程式碼
#include<complex>
#define cp complex<double>

void fft(cp *a, int n, int inv) //inv是取共軛複數的符號
{
  if (n == 1)return;
  int mid = n / 2;
  static cp b[MAXN];
  fo(i, 0, mid - 1)b[i] = a[i * 2], b[i + mid] = a[i * 2 + 1];
  fo(i, 0, n - 1)a[i] = b[i];
  fft(a, mid, inv), fft(a + mid, mid, inv); //分治
  fo(i, 0, mid - 1)
  {
    cp x(cos(2 * pi * i / n), inv * sin(2 * pi * i / n)); //inv取決是否取共軛複數
    b[i] = a[i] + x * a[i + mid], b[i + mid] = a[i] - x * a[i + mid];
  }
  fo(i, 0, n - 1)a[i] = b[i];
}

cp a[MAXN], b[MAXN];
int c[MAXN];
fft(a, n, 1), fft(b, n, 1); //1係數轉點值
fo(i, 0, n - 1)a[i] *= b[i];
fft(a, n, -1); //-1點值轉係數
fo(i, 0, n - 1)c[i] = (int)(a[i].real() / n + 0.5); //注意精度

注意:\(\tt FFT\) 之前要先把 \(n\) 調成 \(2\) 的整數次冪。
很顯然上面的那個是連模板題都過不了的。
所以在這裡我們才需要去考慮怎麼去優化 \(\tt FFT\)

觀察一下原序列和反轉後的序列,需要求的序列實際是原序列下標的二進位反轉!
因此我們對序列按照下標的奇偶性分類的過程其實是沒有必要的。
這樣我們可以 \(O(n)\) 的利用某種操作得到我們要求的序列,然後不斷向上合併就好了。
—— \(\tt luogu\) 某題解

Code 有優化,可過

點擊查看程式碼
#include <bits/stdc++.h>

#define file(a) freopen(a".in", "r", stdin), freopen(a".out", "w", stdout)

#define Enter putchar('\n')
#define quad putchar(' ')

#define N 3000005

namespace IO {

template <class T>
inline void read(T &a);
template <class T, class ...rest>
inline void read(T &a, rest &...x);

template <class T>
inline void write(T x);

}

struct cp {
  double x, y;
  cp (double xx = 0, double yy = 0) {x = xx; y = yy;};

  friend cp operator +(cp p, cp q) {return cp(p.x + q.x, p.y + q.y);}
  friend cp operator -(cp p, cp q) {return cp(p.x - q.x, p.y - q.y);}
  friend cp operator *(cp p, cp q) {return cp(p.x * q.x - p.y * q.y, p.y * q.x + p.x * q.y);}
}a[N], b[N];

const double pi = acos(-1.0);

int n1, n2, n, rev[N], c[N];

inline void FFT(cp *a, int n, int inv) {
  int bit = 0;
  while ((1 << bit) < n) bit++;
  for (int i = 1; i < n; ++i) {
    rev[i] = (rev[i >> 1] >> 1) | ((i & 1) << (bit - 1));
    if (i < rev[i])
      std::swap(a[rev[i]], a[i]);
  }
  for (int mid = 1; mid < n; mid <<= 1) {
    cp temp(cos(pi / mid), inv * sin(pi / mid));
    for (int i = 0; i < n; i += mid * 2) {
      cp omega(1,0);
      for (int j = 0; j < mid; ++j, omega = omega * temp) {
        cp x = a[i + j], y = omega * a[i + j + mid];
        a[i + j] = x + y;
        a[i + j + mid] = x - y;
      }
    }
  }
}

signed main(void) {
  // file("P3803");
  IO::read(n1, n2);
  n = std::max(n1, n2);
  for (int i = 0, num; i <= n1; ++i) IO::read(num), a[i].x = num;
  for (int i = 0, num; i <= n2; ++i) IO::read(num), b[i].x = num;
  n = n1 + n2;
  for (int i = 0; i <= 30; ++i)
    if ((1 << i) > n) {
      n = (1 << i);
      break;
    }

  FFT(a, n, 1); FFT(b, n, 1);
  for (int i = 0; i < n; ++i) a[i] = a[i] * b[i];
  FFT(a, n, -1);
  for (int i = 0; i <= n1 + n2; ++i)
    c[i] = (int)(a[i].x / n + 0.5);
  
  for (int i = 0; i <= n1 + n2; ++i)
    IO::write(c[i]), quad;
  Enter;
}

namespace IO {

template <class T>
inline void read(T &a) {
  T s = 0, t = 1;
  char c = getchar();
  while ((c < '0' || c > '9') && c != '-')
    c = getchar();
  if (c == '-')
    c = getchar(), t = -1;
  while (c >= '0' && c <= '9')
    s = (s << 1) + (s << 3) + (c ^ 48), c = getchar();
  a = s * t;
}
template <class T, class ...rest>
inline void read(T &a, rest &...x) {
  read(a);
  read(x...);
}

template <class T>
inline void write(T x) {
  if (x == 0) putchar('0');
  if (x < 0) putchar('-'), x = -x;
  int top = 0, sta[55] = {0};
  while (x) 
    sta[++top] = x % 10, x /= 10;
  while (top)
    putchar(sta[top] + '0'), top--;
  return ;
}

}

在這裡推薦 某知乎專欄 ,把 \(\tt FFT\) 優化講的很清楚。

\(\tt NTT\) 還是會看的,但是 \(\tt FFT\) 把我給些虛脫了。。。