【算法】KMP算法

简介

KMP算法由 Knuth-Morris-Pratt 三位科学家提出,可用于在一个 文本串 中寻找某 模式串 存在的位置。
本算法可以有效降低在一个 文本串 中寻找某 模式串 过程的时间复杂度。(如果采取朴素的想法则复杂度是 \(O(MN)\)

这里朴素的想法指的是枚举 文本串 的起点,然后让 模式串 从第一位开始一个个地检查是否配对,如果不配对则继续枚举起点。

前置知识

真前缀
指字符串左部的任意子串(不包含自身),如 abcde 中的 a,ab,abc,abcd 都是真前缀但 abcde 不是。

真后缀
指字符串右部的任意子串(不包含自身),如 abcde 中的 e,de,cde,bcde 都是真后缀但 abcde 不是。

前缀函数
一个字符串中最长的、相等的真前缀与真后缀的长度, 如AABBAAA对应的前缀函数值是 \(2\)

原理

注意:在分析的时候,我们规定字符串的下标从 \(1\) 开始。

开始:
我们记扫描模式串的指针为j,而扫描文本串的指针为i,假设一开始i,j都在起点,然后让它们一直下去直到完全匹配或者失配,比如:

j
ABCD

i
ABCDEFG

然后

 j
ABCD

 i
ABCDEFG

最后在此完成了一次匹配,类似地如果ABCD改为ABCC则在此失配。

   j
ABCD

   i
ABCDEFG

i,j运作模式如上。



KMP算法就是,当模式串和文本串失配的时候,j指针从真后缀的末尾跳到真前缀的末尾,然后从真前缀后一位开始继续匹配。(从而起到减少配对次数,这便是KMP算法的核心原理)

结合例子解释:

模式串: \(AABBAAA\)

文本串: \(AABBAABBAAA\)

j指针在最后一个A处失配。

      j
AABBAAA
      i
AABBAABBAAA

因为此时 以j为尾的前缀 所对应的前缀函数值是 \(2\) ,所以 j指针 跳到这里:

 j
AABBAAA
      i
AABBAABBAAA

然后从下一位开始继续配对:

  j
AABBAAA
      i
AABBAABBAAA

最后

      j
AABBAAA
          i
AABBAABBAAA

可以看出,KMP能够有效减少配对次数。

实现

我们记模式串p文本串s

从上面的模拟中,我们发现需要预处理出一个数组(记之为next[]),它储存模式串中前缀对应的前缀函数\(\pi()\),如对于字符串ABCABC

\(\pi(0)=0\) (因为什么都没有)
\(\pi(1)=0\)A甚至没有真前缀真后缀
\(\pi(2)=0\)AB
\(\pi(3)=0\)ABC
\(\pi(4)=1\)ABCA
\(\pi(5)=2\)ABCAB
\(\pi(6)=3\)ABCABC

同样地,我们发现如果用暴力朴素的想法来统计复杂度是 O(N^2) 不好,于是采用类似于上面的方法,只不过模式串配对的对象是自己罢了。

可以结合代码理解,并注意举例,尝试在纸上模拟这个过程。

for(int i=2,j=0;i<=lenp;i++){
        while(j && p[j+1]!=p[i]) j=next_[j]; // 如果j指向元素的下一个元素会和当前配对位置失配,则j跳回去
        if(p[j+1]==p[i]) j++; //如果能够配对上,j++
        next_[i]=j; //记录当前位置的前缀函数π
}

完整代码:

#include<bits/stdc++.h>
using namespace std;

const int N=1e6+5;
char p[N],s[N];
int next_[N];

int main(){
    cin>>s+1>>p+1;

    int lenp=strlen(p+1),lens=strlen(s+1);
    // build next array
    for(int i=2,j=0;i<=lenp;i++){
        while(j && p[j+1]!=p[i]) j=next_[j]; // 如果j指向元素的下一个元素会和当前配对位置失配,则j跳回去
        if(p[j+1]==p[i]) j++; //如果能够配对上,j++
        next_[i]=j; //记录当前位置的前缀函数π
    }

    for(int i=1,j=0;i<=lens;i++){
        while(j && p[j+1]!=s[i]) j=next_[j];
        if(p[j+1]==s[i]) j++;

        // if match
        if(j==lenp){
            j=next_[j];
            cout<<i-lenp+1<<endl;
        }
    }

    for(int i=1;i<=lenp;i++) cout<<next_[i]<<' ';
    cout<<endl;

    return 0;
}

复杂度

\(O(N+M)\)