【演算法】KMP演算法

簡介

KMP演算法由 Knuth-Morris-Pratt 三位科學家提出,可用於在一個 文本串 中尋找某 模式串 存在的位置。
本演算法可以有效降低在一個 文本串 中尋找某 模式串 過程的時間複雜度。(如果採取樸素的想法則複雜度是 \(O(MN)\)

這裡樸素的想法指的是枚舉 文本串 的起點,然後讓 模式串 從第一位開始一個個地檢查是否配對,如果不配對則繼續枚舉起點。

前置知識

真前綴
指字元串左部的任意子串(不包含自身),如 abcde 中的 a,ab,abc,abcd 都是真前綴但 abcde 不是。

真後綴
指字元串右部的任意子串(不包含自身),如 abcde 中的 e,de,cde,bcde 都是真後綴但 abcde 不是。

前綴函數
一個字元串中最長的、相等的真前綴與真後綴的長度, 如AABBAAA對應的前綴函數值是 \(2\)

原理

注意:在分析的時候,我們規定字元串的下標從 \(1\) 開始。

開始:
我們記掃描模式串的指針為j,而掃描文本串的指針為i,假設一開始i,j都在起點,然後讓它們一直下去直到完全匹配或者失配,比如:

j
ABCD

i
ABCDEFG

然後

 j
ABCD

 i
ABCDEFG

最後在此完成了一次匹配,類似地如果ABCD改為ABCC則在此失配。

   j
ABCD

   i
ABCDEFG

i,j運作模式如上。



KMP演算法就是,當模式串和文本串失配的時候,j指針從真後綴的末尾跳到真前綴的末尾,然後從真前綴後一位開始繼續匹配。(從而起到減少配對次數,這便是KMP演算法的核心原理)

結合例子解釋:

模式串: \(AABBAAA\)

文本串: \(AABBAABBAAA\)

j指針在最後一個A處失配。

      j
AABBAAA
      i
AABBAABBAAA

因為此時 以j為尾的前綴 所對應的前綴函數值是 \(2\) ,所以 j指針 跳到這裡:

 j
AABBAAA
      i
AABBAABBAAA

然後從下一位開始繼續配對:

  j
AABBAAA
      i
AABBAABBAAA

最後

      j
AABBAAA
          i
AABBAABBAAA

可以看出,KMP能夠有效減少配對次數。

實現

我們記模式串p文本串s

從上面的模擬中,我們發現需要預處理出一個數組(記之為next[]),它儲存模式串中前綴對應的前綴函數\(\pi()\),如對於字元串ABCABC

\(\pi(0)=0\) (因為什麼都沒有)
\(\pi(1)=0\)A甚至沒有真前綴真後綴
\(\pi(2)=0\)AB
\(\pi(3)=0\)ABC
\(\pi(4)=1\)ABCA
\(\pi(5)=2\)ABCAB
\(\pi(6)=3\)ABCABC

同樣地,我們發現如果用暴力樸素的想法來統計複雜度是 O(N^2) 不好,於是採用類似於上面的方法,只不過模式串配對的對象是自己罷了。

可以結合程式碼理解,並注意舉例,嘗試在紙上模擬這個過程。

for(int i=2,j=0;i<=lenp;i++){
        while(j && p[j+1]!=p[i]) j=next_[j]; // 如果j指向元素的下一個元素會和當前配對位置失配,則j跳回去
        if(p[j+1]==p[i]) j++; //如果能夠配對上,j++
        next_[i]=j; //記錄當前位置的前綴函數π
}

完整程式碼:

#include<bits/stdc++.h>
using namespace std;

const int N=1e6+5;
char p[N],s[N];
int next_[N];

int main(){
    cin>>s+1>>p+1;

    int lenp=strlen(p+1),lens=strlen(s+1);
    // build next array
    for(int i=2,j=0;i<=lenp;i++){
        while(j && p[j+1]!=p[i]) j=next_[j]; // 如果j指向元素的下一個元素會和當前配對位置失配,則j跳回去
        if(p[j+1]==p[i]) j++; //如果能夠配對上,j++
        next_[i]=j; //記錄當前位置的前綴函數π
    }

    for(int i=1,j=0;i<=lens;i++){
        while(j && p[j+1]!=s[i]) j=next_[j];
        if(p[j+1]==s[i]) j++;

        // if match
        if(j==lenp){
            j=next_[j];
            cout<<i-lenp+1<<endl;
        }
    }

    for(int i=1;i<=lenp;i++) cout<<next_[i]<<' ';
    cout<<endl;

    return 0;
}

複雜度

\(O(N+M)\)