KMP&Z函數詳解

2022-09-19 09:02:03

KMP

一些簡單的定義:

  • 真字首:不是整個字串的字首
  • 真字尾:不是整個字串的字尾

當然不可能這麼簡單的,來個重要的定義

  • 字首函數:
    給定一個長度為\(n\)的字串\(s\),其 \(字首函數\) 為一個長度為\(n\)的陣列\(\pi\),其中\(\pi_i\)表示
    1. 如果字串\(s[0...i]\)存在一對相等的真字首和真字尾\(s[0...k]~and~s[i-(k-1)...i]\),則\(\pi_i\)為這個真字首(真字尾)的長度\(k\)
    2. 如果有不止一對,則\(\pi_i\)為其中最長的一對
    3. 若沒有最長的,則\(\pi_i=0\)

簡而言之,\(\pi_i\)為字串\(s[0...i]\)最長相等的真字首和真字尾長度

特別的,規定\(\pi_0\)=0

舉個例子,對於字串\(s=\)"\(abcabcd\)",其字首函數\(\pi=\{0,0,0,1,2,3,0\}\)

考慮如何去求字首函數,最暴力的做法肯定使\(O(n^3)\)的,列舉字首位置、真字首(真字尾)的長度和真字首(真字尾)的每一位,程式碼就不放出來了(其實是沒寫)

優化一

我們發現當\(i+1\)時,\(\pi_i\)最多\(+1\),也就是說我們列舉\(\pi_i\)的長度時,上界為\(\pi{i-1}+1\),這樣複雜度可以被優化為\(O(n^2)\)

程式碼如下

void Getnex(string str){
    int len=str.size();
    nex[0]=0;
    for (int i=1;i<n;i++) {
        for (int j=nex[i-1];j>=0;j--) {
            if (str.substr(0,j)==str.substr(i-j+1,j)) {
                nex[i]=j;
                break;
            }
        }
    }
}

如何證明這個複雜度呢

我們發現每次進行一次\(substr\)操作,都意味著\(\pi_i\)的值都會\(-1\)

顯然有一種最壞的情況是,\(p_i\)的值先變成\(n-1\),然後再掉回\(0\)

容易看出這不會超過\(O(n)\)次,然後每次\(substr\)的複雜度為\(O(n)\)

所以總複雜度\(O(n^2)\)

優化二

我們在優化一中發現了一個性質,當\(s[i+1]==s[\pi_i]\)時,\(\pi_i=\pi_{i-1}+1\)

考慮把這個性質推下去,當\(s[i+1]!=s[\pi_i]\)

我們發現,我們要找的轉移點\(j\)是要滿足\(i-1\)的字首性質的

即滿足\(s[1...j]=s[i-(j-1)...i]\)

\(\pi_i\)不行時我們自然要去找下一個滿足性質的轉移點

很容易想到這東西不就是\(\pi_{\pi_i}嗎\)

由於真字首和真字尾相等,所以\(\pi_{\pi_i}\)必然滿足既是真字首的真字首又是真字尾的真字尾

可能有點繞,舉個例子\(s=\)"\(abaabaa\)"(下標從\(0\)開始)

當我們求\(\pi_6\)時,前面的\(\pi\)值為

\(\pi[0...5]=\{0,0,1,1,2,3\}\)

我們發現\(s[6]!=s[\pi_5=3]\)

那麼我們就需要找到下一個轉移點\(j\)滿足\(s[0...j]=s[5-(j+1)...5]\)

因為\(\pi_5=3\)所以\(s[0...5]\)的滿足字首性質的真字首(真字尾)為\(aba\)

所以對於字串"\(aba\)"滿足字首性質的真字首(真字尾)一定滿足\(s[0...5]\)的字首性質

所以轉移點即為\(\pi_3=\pi_{\pi_5}\)

至此,求字首函數便可以優化成\(O(n)\)

void Getnex(std::string S) {
    for (int i=2,j=0;i<S.size();i++) {
        while(j && S[j+1]!=S[i]) j=nex[j];
        if (S[j+1]==S[i]) j++;
        nex[i]=j;   
    }
}

前話到此完結,接下來是真正的\(KMP\)

\(KMP\)是對於字首函數的典型運用

舉個例子,給定一個文字\(t\)和字串\(s\),我們嘗試求出\(s\)\(t\)中的所有出現

我們記\(n,m\)\(s\)\(t\)的長度

我們構造一個字串為\(s+\)'\(.\)'\(+t\),其中\(.\)為不在\(s,t\)中出現的分隔符

計算出這個字串的字首函數,考慮這個字首函數除去前\(n+1\)個值意味著什麼

根據定義,\(\pi_i\)為右端點為\(i\),且為一個字首的最長真子串長度

且由於有分割符的存在,\(\pi\)不可能超過\(n\)

\(\pi_i=n\)時,則意味著\(s\)\(t\)中完整出現一次,其右端點為\(i\)

因此\(KMP\)可以在\(O(n+m)\)的複雜度內解決問題

void KMP(std::string S,std::string T) {
    for (int i=1,j=0;i<S.size();i++) {
        while(j && T[j+1]!=S[i]) j=nex[j];
        if (T[j+1]==S[i]) j++;
        if (j==m-1) {
            std::cout<<i-m+2<<std::endl;
            j=nex[j];
        }
    }
}

字串的週期

定義:

  • 對於字串\(s\),若存在\(p\)滿足\(s[i]=s[i+p](i\in[0,|s|-p-1])\),則\(p\)\(s\)的週期

  • 對於字串\(s\),若存在\(r\)滿足\(s\)長度為\(r\)的字首和長度為\(r\)的字尾相等,則稱\(s\)長度為\(r\)的字首是\(s\)\(border\)

由這兩個定義不難看出\(|s|-r\)\(s\)的週期

根據字首函數的定義我們可以得出\(s\)所有\(border\)長度,即\(\pi_{n-1},\pi_{\pi_{n-1}},...\)

所以我們可以在\(O(n)\)的時間複雜度內求出\(s\)的所有周期

其中最小週期為\(n-\pi_{n-1}\)

統計每個字首的出現次數

以下預設字串下標從\(1\)開始

主要是兩種問題,一個是求\(s\)的字首在\(s\)中的出現次數,另一個是求\(s\)的字首在另一個字串\(t\)中的出現次數

考慮位置\(i\)的字首函數值\(\pi_i\),根據定義,其意味著字串\(s\)一個長度為 的字首在位置\(i\)出現並以\(i\)為右端點,同時不存在一個更長的字首滿足前述定義。

與此同時,更短的字首可能以該位置為右端點。

容易看出,我們遇到了在計算字首函數時已經回答過的問題:給定一個長度為\(j\)的字首,同時其也是一個右端點位於\(i\)的字尾,下一個更小的字首長度\(k<j\)是多少?該長度的字首需同時也是一個右端點為\(i\)的字尾。

因此以位置\(i\)為右端點,有長度為\(\pi_i\)的字首,有長度為\(\pi_{\pi_i}\)的字首,等等,直到長度變為0。

故而我們可以通過下述方式計算答案。

void Getcnt(std::string str) {
    for (int i=1;i<str.size();i++) ans[nex[i]]++;
    for (int i=str.size()-1;i>0;i--) ans[nex[i]]+=ans[nex[i]];
    for (int i=1;i<str.size();i++) ans[i]++;
}

例題:CF432D Prefixes and Suffixes

題意:

給你一個長度為n的長字串,「完美子串」既是它的字首也是它的字尾,求「完美子串」的個數且統計這些子串的在長字串中出現的次數

模板題,直接寫就行

#include <ctime>
#include <cstdio>
#include <iostream>
#define file(a) freopen(#a".in","r",stdin),freopen(#a".out","w",stdout)

const int maxn=1e5+5;
int n,nex[maxn],ans[maxn];
std::string str;

void chkmax(int &x,int y) {if (x<y) x=y;}
void chkmin(int &x,int y) {if (x>y) x=y;}
int read() {
    int x=0,f=1;
    char ch=getchar();
    while(ch<'0' || ch>'9') {if (ch=='-') f=-1;ch=getchar();}
    while(ch<='9' && ch>='0') {x=(x<<3)+(x<<1)+ch-'0';ch=getchar();}
    return x*f;
}

void Getnex(std::string s) {
    for (int i=2,j=0;i<s.size();i++) {
        while(j && s[j+1]!=s[i]) j=nex[j];
        if (s[j+1]==s[i]) j++;
        nex[i]=j;
    }
}
void out(int i,int cnt) {
    if (!i) {std::cout<<cnt<<std::endl;return ;}
    out(nex[i],cnt+1);
    std::cout<<i<<' '<<ans[i]<<std::endl;
}

int main() {
    std::cin>>str;
    str=" "+str;
    Getnex(str);
    // for (int i=1;i<str.size();i++) std::cout<<nex[i]<<' ';puts("");
    for (int i=1;i<str.size();i++) ans[nex[i]]++;
    for (int i=str.size();i>=1;i--) ans[nex[i]]+=ans[i];
    for (int i=1;i<str.size();i++) ans[i]++;
    out(str.size()-1,0);
    return 0;
}

Z函數(擴充套件KMP)

定義:

對於一個字串\(s\)\(z[i]\)表示\(s\)\(s[i...n-1]\)\(LCP\)(最長公共字首)的長度,\(z\)則被稱為\(s\)\(Z\)函數

我們在計算的時候,可以通過前面已知的\(z\)來計算

對於\(i\),我們稱區間\([i,i+z[i]-1]\)\(i\)的匹配段

我們在演演算法過程中維護右端點最靠右的匹配段,記作\([l,r]\)

則有\([l,r]\)\(s\)的字首,並且在計算\(z[i]\)時我們保證\(l\leq i\)

演演算法流程

最開始時,\(l=r=0\)

在計算\(z[i]\)的過程中

  • \(i\leq r\),則有\(s[i,r]=s[i-l,r-l]\),所以\(z[i]\ge min(z[i-l,r-i+1])\)

    • \(z[i-l]<r-i+1\),則\(z[i]=z[i-l]\)
    • \(z[i-l]\ge r-i+1\),我們令\(z[i]=r-i+1\),然後暴力向後列舉字元
  • \(i>r\),我們直接暴力從\(s[i]\)開始比較,求出\(z[i]\)

  • 求出\(z[i]\)後,還要更新\(l,r\)

void GetZ(std::string s) {
    int l=0,r=0;
    for (int i=1;i<s.size();i++) {
        if (i<=r && z[i-l]<r-i+1) z[i]=z[i-l];
        else {
            z[i]=std::max(0,r-i+1);
            while(i+z[i]<s.size() && s[z[i]]==s[i+z[i]]) z[i]++;
        }
        if (i+z[i]-1>r) {l=i;r=i+z[i]-1};
    }
}

複雜度分析

對於內層的\(while\),每次執行都會使\(r\)向後移動至少一位,而\(r<n-1\),所以總共最多做\(n\)

加上外層的\(for\),總複雜度\(O(n)\)