由於新文章的做法與舊文章不同, 因此 KMP 演演算法仍保留舊文章, 且經過模板題測驗, 新的做法明顯慢於舊的做法, 但是, 新做法更好理解.
字首 是指從串首開始到某個位置 \(i\) 結束的一個特殊子串.
真字首 指除了 \(S\) 本身的 \(S\) 的字首.
舉例來說, 字串 abcabeda
的所有字首為 {a, ab, abc, abca, abcab, abcabe, abcabed, abcabeda}
, 而它的真字首為 {a, ab, abc, abca, abcab, abcabe, abcabed}
.
字尾 是指從某個位置 \(i\) 開始到整個串末尾結束的一個特殊子串.
真字尾 指除了 \(S\) 本身的 \(S\) 的字尾.
舉例來說, 字串 abcabeda
的所有字尾為 {a, da, eda, beda, abeda, cabeda, bcabeda, abcabeda}
, 而它的真字尾為 {a, da, eda, beda, abeda, cabeda, bcabeda}
.
定義: 給定一個長度為 \(n\) 的字串 \(s\), 其字首函數被定義為一個長度為 \(n\) 的陣列 nxt
. 其中 nxt[i]
是子串 s[0 ~ i]
最長的相等的真字首和真字尾的長度.
用數學語言描述如下:
特別地, nxt[0] = 0
, 因為不存在真字首和真字尾.
舉例來說, 對於字串 aabaaab
,
nxt[0] = 0
, a
沒有真字首和真字尾.
nxt[1] = 1
, aa
只有一對相等的真字首和真字尾: a
, 長度為 \(1\).
nxt[2] = 0
, aab
沒有相等的真字首和真字尾.
nxt[3] = 1
, aaba
只有一對相等的真字首和真字尾: a
, 長度為 \(1\).
nxt[4] = 2
, aabaa
相等的真字首和真字尾有 a
, aa
, 最長的長度為 \(2\).
nxt[5] = 2
, aabaaa
相等的真字首和真字尾有 a
, aa
, 最長的長度為 \(2\).
nxt[6] = 3
, aabaaab
相等的真字首和真字尾只有 aab
, 最長的長度為 \(3\).
cin >> s1;
len1 = s1.length();
for (int i = 1; i < len1; ++ i) {
for (int j = i; j; -- j) {
if (s1.substr(0, j) == s1.substr(i - (j - 1), j)) {
nxt[i] = j;
break ;
}
}
}
第一個重要的觀察是 相鄰的字首函數值至多增加 \(1\).
參照下圖所示, 只需如此考慮: 當取一個儘可能大的 nxt[i + 1]
時, 必然要求新增的 s[i + 1]
也與之對應的字元匹配, 即 s[i + 1] = s[nxt[i]]
, 此時 s[i + 1] = s[i] + 1
.
所以當移動到下一個位置時, 字首函數的值要麼增加一, 要麼維持不變, 要麼減少.
當 s[i+1] != s[nxt[i]]
時, 我們希望找到對於子串 s[0 ~ i]
, 僅次於 nxt[i]
的第二長度 \(j\), 使得在位置 \(i\) 的字首性質仍得以保持, 也即 s[0 ~ (j - 1)] = s[(i - j + 1) ~ i]
:
如果我們找到了這樣的長度 \(j\), 那麼僅需要再次比較 s[i + 1]
和 s[j]
. 如果它們相等, 那麼就有 nxt[i + 1] = j + 1
. 否則, 我們需要找到子串 s[0 ~ i]
僅次於 \(j\) 的第二長度 \(j_{2}\), 使得字首性質得以保持, 如此反覆, 直到 \(j = 0\). 如果 s[i + 1] != s[0]
, 則 nxt[i + 1] = 0
.
觀察上圖可以發現, 因為 s[0 ~ nxt[i] - 1] = s[i - nxt[i] + 1 ~ i]
, 所以對於 s[0 ~ i]
的第二長度 \(j\), 有這樣的性質:
s[0 ~ j - 1] = s[i - j + 1 ~ i]= s[nxt[i] - j ~ nxt[i] - 1]
也就是說 \(j\) 等價於子串 s[nxt[i] - 1]
的字首函數值 (你可以把上面的 \(i\) 換成 nxt[i] - 1
), 即 j = nxt[nxt[i] - 1]
. 同理, 次於 \(j\) 的第二長度等價於 s[j - 1]
的字首函數值.
cin >> s1;
len1 = s1.length();
for (int i = 1; i < len1; ++ i) {
int j = nxt[i - 1];
while (j && s1[i] != s1[j]) {
j = nxt[j - 1];
}
if (s1[i] == s1[j]) {
++ j;
}
nxt[i] = j;
}
給定一個文字 \(t\) 和一個字串 \(s\), 我們嘗試找到並展示 \(s\) 在 \(t\) 中的所有出現.
為了簡便起見, 我們用 \(n\) 表示字串 \(s\) 的長度, 用 \(m\) 表示文字 \(t\) 的長度.
我們構造一個字串 \(s\) + #
+ \(t\), 其中 #
為一個既不出現在 \(s\) 中也不出現在 \(t\) 中的分隔符.
接下來計算該字串的字首函數. 現在考慮該字首函數除去最開始 \(n + 1\) 個值 (即屬於字串 \(s\) 和分隔符的函數值) 後其餘函數值的意義. 根據定義,nxt[i]
為右端點在 \(i\) 且同時為一個字首的最長真子串的長度, 具體到我們的這種情況下, 其值為與 \(s\) 的字首相同且右端點位於 \(i\) 的最長子串的長度. 由於分隔符的存在, 該長度不可能超過 \(n\). 而如果等式 nxt[i] = n
成立, 則意味著 \(s\) 完整出現在該位置 (即其右端點位於位置 \(i\)). 注意該位置的下標是對字串 \(s\) + #
+ \(t\) 而言的.
因此如果在某一位置 \(i\) 有 nxt[i] = n
成立, 則字串 \(s\) 在字串 \(t\) 的 \(i - (n - 1) - (n + 1) = i - 2n\) 處出現.
正如在字首函數的計算中已經提到的那樣, 如果我們知道字首函數的值永遠不超過一特定值, 那麼我們不需要儲存整個字串以及整個字首函數, 而只需要二者開頭的一部分. 在我們這種情況下這意味著只需要儲存字串 \(s\) + #
以及相應的字首函數值即可. 我們可以一次讀入字串 \(t\) 的一個字元並計算當前位置的字首函數值.
因此 Knuth–Morris–Pratt 演演算法(簡稱 KMP 演演算法)用 \(O_{n + m}\) 的時間以及 \(O_{n}\) 的記憶體解決了該問題.
/*
The code was written by yifan, and yifan is neutral!!!
*/
#include <bits/stdc++.h>
using namespace std;
typedef long long ll;
template<typename T>
inline T read() {
T x = 0;
bool fg = 0;
char ch = getchar();
while (ch < '0' || ch > '9') {
fg |= (ch == '-');
ch = getchar();
}
while (ch >= '0' && ch <= '9') {
x = (x << 3) + (x << 1) + (ch ^ 48);
ch = getchar();
}
return fg ? ~x + 1 : x;
}
const int N = 1e6 + 5;
int nxt[N << 1];
char s1[N], s2[N], cur[N << 1];
inline void get_nxt(char* s) {
int len = strlen(s);
for (int i = 1; i < len; ++ i) {
int j = nxt[i - 1];
while (j && s[i] != s[j]) {
j = nxt[j - 1];
}
if (s[i] == s[j]) {
++ j;
}
nxt[i] = j;
}
}
int main() {
cin >> s1 >> s2;
scanf("%s%s", s1, s2);
strcpy(cur, s2);
strcat(cur, "#");
strcat(cur, s1);
get_nxt(cur);
int l1 = strlen(s1), l2 = strlen(s2);
for (int i = l2 + 1; i <= l1 + l2; ++ i) {
if (nxt[i] == l2) {
cout << i - 2 * l2 + 1 << '\n';
}
}
for (int i = 0; i < l2; ++ i) {
cout << nxt[i] << ' ';
}
return 0;
}