「學習筆記」AC 自動機

2023-07-22 06:01:14

AC 自動機是 以 Trie 的結構為基礎,結合 KMP 的思想 建立的自動機,用於解決多模式匹配等任務。

Trie 的構建

這裡需要仔細解釋一下 Trie 的結點的含義,Trie 中的結點表示的是某個模式串的字首。我們在後文也將其稱作狀態。一個結點表示一個狀態,Trie 的邊就是狀態的轉移。

形式化地說,對於若干個模式串 \(s_1, s_2 \dots s_n\),將它們構建一棵字典樹後的所有狀態的集合記作 Q。

失配指標

個人感覺這裡是最難理解的。

AC 自動機利用一個 fail 指標來輔助多模式串的匹配。

狀態 \(u\) 的 fail 指標指向另一個狀態 \(v\),其中 \(v \in Q\),且 \(v\)\(u\) 的最長字尾(即在若干個字尾狀態中取最長的一個作為 fail 指標)。

只需要知道,AC 自動機的失配指標指向當前狀態的最長字尾狀態即可。

構建指標

構建 fail 指標,可以參考 KMP 中構造 Next 指標的思想。

考慮字典樹中當前的結點 \(u\)\(u\) 的父結點是 \(p\)\(p\) 通過字元 \(c\) 的邊指向 \(u\),即 \(trie[p,\mathtt{c}]=u\)。假設深度小於 \(u\) 的所有結點的 fail 指標都已求得。

  1. 如果 \(\text{trie}[\text{fail}[p],\mathtt{c}]\) 存在:則讓 \(u\) 的 fail 指標指向 \(\text{trie}[\text{fail}[p],\mathtt{c}]\)。相當於在 \(p\)\(\text{fail}[p]\) 後面加一個字元 \(c\),分別對應 \(u\) 和 fail[u]。

  2. 如果 \(\text{trie}[\text{fail}[p],\mathtt{c}]\) 不存在:那麼我們繼續找到 \(\text{trie}[\text{fail}[\text{fail}[p]],\mathtt{c}]\)。重複 \(1\) 的判斷過程,一直跳 fail 指標直到根結點。

  3. 如果真的沒有,就讓 fail 指標指向根結點。
    如此即完成了 \(\text{fail}[u]\) 的構建。

如此即完成了 \(\text{fail}[u]\) 的構建。

實現

定義

struct node {
    int fail;
    int tr[26];
    int End;
} ac[N];

fail 是失配指標,tr 是字典樹,End 是當前狀態是否為一個字串的結束。

插入

這裡就是最基本的字典樹插入操作。

void Insert(char* s) {
    int l = strlen(s), u = 0;
    for (int i = 0; i < l; ++ i) {
        if (ac[u].tr[s[i] - 'a'] == 0) {
            ac[u].tr[s[i] - 'a'] = ++ tot;
        }
        u = ac[u].tr[s[i] - 'a'];
    }
    ++ ac[u].End;
}

構建失敗指標

我們用佇列廣搜的方式來構建失敗指標,按照上面的步驟:

  • 如果 \(\text{trie}[\text{fail}[p],\mathtt{c}]\) 存在:則讓 \(u\) 的 fail 指標指向 \(\text{trie}[\text{fail}[p],\mathtt{c}]\)。相當於在 \(p\)\(\text{fail}[p]\) 後面加一個字元 \(c\),分別對應 \(u\) 和 fail[u]。

  • 如果 \(\text{trie}[\text{fail}[p],\mathtt{c}]\) 不存在:那麼我們繼續找到 \(\text{trie}[\text{fail}[\text{fail}[p]],\mathtt{c}]\)。重複 \(1\) 的判斷過程,一直跳 fail 指標直到根結點。

  • 如果真的沒有,就讓 fail 指標指向根結點。
    如此即完成了 \(\text{fail}[u]\) 的構建。

void get_fail() {
	queue<int> q;
	for (int i = 0; i < 26; ++ i) {
		if (ac[0].tr[i] != 0) {
			ac[ac[0].tr[i]].fail = 0;
			q.emplace(ac[0].tr[i]);
		}
	}
	while (!q.empty()) {
		int u = q.front();
		q.pop();
		for (int i = 0; i < 26; ++ i) {
			if (ac[u].tr[i]) {
				ac[ac[u].tr[i]].fail = ac[ac[u].fail].tr[i];
				q.emplace(ac[u].tr[i]);
			} else {
				ac[u].tr[i] = ac[ac[u].fail].tr[i];
			}
		}
	}
}

查詢

這裡我們用模板題來說明。

查詢有多少個模式串出現過

P3808 【模板】AC 自動機(簡單版) - 洛谷 | 電腦科學教育新生態 (luogu.com.cn)

int ask(char* s) {
	int l = strlen(s);
	int u = 0, ans = 0;
	for (int i = 0; i < l; ++ i) {
		u = ac[u].tr[s[i] - 'a'];
		for (int cur = u; cur && (~ac[cur].End); cur = ac[cur].fail) {
			ans += ac[cur].End;
			ac[cur].End = -1;
		}
	}
	return ans;
}

這裡給 End 打上標記,是為了防止重複搜到這一個模式串,然後重複加入了答案。

完整程式碼:

/*
  The code was written by yifan, and yifan is neutral!!!
 */

#include <bits/stdc++.h>
using namespace std;
typedef long long ll;

template<typename T>
inline T read() {
	T x = 0;
	bool fg = 0;
	char ch = getchar();
	while (ch < '0' || ch > '9') {
		fg |= (ch == '-');
		ch = getchar();
	}
	while (ch >= '0' && ch <= '9') {
		x = (x << 3) + (x << 1) + (ch ^ 48);
		ch = getchar();
	}
	return fg ? ~x + 1 : x;
}

const int N = 1e6 + 5;

int n, tot;
char s[N];

struct node {
	int fail;
	int tr[26];
	int End;
} ac[N];

void Insert(char* s) {
	int l = strlen(s), u = 0;
	for (int i = 0; i < l; ++ i) {
		if (ac[u].tr[s[i] - 'a'] == 0) {
			ac[u].tr[s[i] - 'a'] = ++ tot;
		}
		u = ac[u].tr[s[i] - 'a'];
	}
	++ ac[u].End;
}

void get_fail() {
	queue<int> q;
	for (int i = 0; i < 26; ++ i) {
		if (ac[0].tr[i] != 0) {
			ac[ac[0].tr[i]].fail = 0;
			q.emplace(ac[0].tr[i]);
		}
	}
	while (!q.empty()) {
		int u = q.front();
		q.pop();
		for (int i = 0; i < 26; ++ i) {
			if (ac[u].tr[i]) {
				ac[ac[u].tr[i]].fail = ac[ac[u].fail].tr[i];
				q.emplace(ac[u].tr[i]);
			} else {
				ac[u].tr[i] = ac[ac[u].fail].tr[i];
			}
		}
	}
}

int ask(char* s) {
	int l = strlen(s);
	int u = 0, ans = 0;
	for (int i = 0; i < l; ++ i) {
		u = ac[u].tr[s[i] - 'a'];
		for (int cur = u; cur && (~ac[cur].End); cur = ac[cur].fail) {
			ans += ac[cur].End;
			ac[cur].End = -1;
		}
	}
	return ans;
}

int main() {
	n = read<int>();
	for (int i = 1; i <= n; ++ i) {
		scanf("%s", s + 1);
		Insert(s + 1);
	}
	ac[0].fail = 0;
	get_fail();
	scanf("%s", s + 1);
	cout << ask(s + 1) << '\n';
	return 0;
}

查詢出現次數最多的模式串

P3796 【模板】AC 自動機(加強版) - 洛谷 | 電腦科學教育新生態 (luogu.com.cn)

這裡 End 儲存的不再是簡單的 \(1\) 了,而是當前這個狀態對應的模式串的編號,目的是最後輸出字串。

void Insert(string s, int num) {
	int u = 0, l = s.size();
	for (int i = 0; i < l; ++ i) {
		if (!ac[u].tr[s[i] - 'a']) {
			ac[u].tr[s[i] - 'a'] = ++ cnt;
			clr(cnt);
		}
		u = ac[u].tr[s[i] - 'a'];
	}
	ac[u].End = num;
}

for (int i = 1; i <= n; ++ i) {
	cin >> st[i];
	Insert(st[i], i);
	Ans[i].first = 0;
	Ans[i].second = i;
}

除了查詢和主函數,其他程式碼都是一樣的。

查詢程式碼:

void ask(char* s) {
	int l = strlen(s);
	int u = 0;
	for (int i = 0; i < l; ++ i) {
		u = ac[u].tr[s[i] - 'a'];
		for (int cur = u; cur; cur = ac[cur].fail) {
			++ Ans[ac[cur].End].first;
		}
	}
}

這裡的 Ans 是定義的答案陣列,first 是記錄出現的次數,second 是該狀態的編號。

完整程式碼:

/*
  The code was written by yifan, and yifan is neutral!!!
 */

#include <bits/stdc++.h>
using namespace std;
typedef long long ll;
typedef pair<int, int> pii;

template<typename T>
inline T read() {
	T x = 0;
	bool fg = 0;
	char ch = getchar();
	while (ch < '0' || ch > '9') {
		fg |= (ch == '-');
		ch = getchar();
	}
	while (ch >= '0' && ch <= '9') {
		x = (x << 3) + (x << 1) + (ch ^ 48);
		ch = getchar();
	}
	return fg ? ~x + 1 : x;
}

const int N = 1e6 + 5;

int n, cnt;
char s[N];
string st[200];

struct node {
	int fail, End;
	int tr[26];
} ac[N];

pair<int, int> Ans[N];

void clr(int u) {
	for (int i = 0; i < 26; ++ i) {
		ac[u].tr[i] = 0;
	}
	ac[u].fail = ac[u].End = 0;
}

void Insert(string s, int num) {
	int u = 0, l = s.size();
	for (int i = 0; i < l; ++ i) {
		if (!ac[u].tr[s[i] - 'a']) {
			ac[u].tr[s[i] - 'a'] = ++ cnt;
			clr(cnt);
		}
		u = ac[u].tr[s[i] - 'a'];
	}
	ac[u].End = num;
}

void get_fail() {
	queue<int> q;
	for (int i = 0; i < 26; ++ i) {
		if (ac[0].tr[i] != 0) {
			ac[ac[0].tr[i]].fail = 0;
			q.emplace(ac[0].tr[i]);
		}
	}
	while (!q.empty()) {
		int u = q.front();
		q.pop();
		for (int i = 0; i < 26; ++ i) {
			if (ac[u].tr[i]) {
				ac[ac[u].tr[i]].fail = ac[ac[u].fail].tr[i];
				q.emplace(ac[u].tr[i]);
			} else {
				ac[u].tr[i] = ac[ac[u].fail].tr[i];
			}
		}
	}
}

void ask(char* s) {
	int l = strlen(s);
	int u = 0;
	for (int i = 0; i < l; ++ i) {
		u = ac[u].tr[s[i] - 'a'];
		for (int cur = u; cur; cur = ac[cur].fail) {
			++ Ans[ac[cur].End].first;
		}
	}
}

void work() {
	cnt = 0;
	clr(0);
	for (int i = 1; i <= n; ++ i) {
		cin >> st[i];
		Insert(st[i], i);
		Ans[i].first = 0;
		Ans[i].second = i;
	}
	get_fail();
	scanf("%s", s + 1);
	ask(s + 1);
	sort(Ans + 1, Ans + n + 1, [](pii x, pii y) {
		return x.first == y.first ? x.second < y.second : x.first > y.first;
	});
	int l = 1;
	printf("%d\n", Ans[1].first);
	while (Ans[l].first == Ans[1].first) {
		cout << st[Ans[l].second] << '\n';
		++ l;
	}
}

int main() {
	n = read<int>();
	while (n) {
		work();
		n = read<int>();
	}
	return 0;
}

優化

先拿這道題來引入。P5357 【模板】AC 自動機(二次加強版) - 洛谷 | 電腦科學教育新生態 (luogu.com.cn)

你會發現它與 P3796 【模板】AC 自動機(加強版) - 洛谷 | 電腦科學教育新生態 (luogu.com.cn) 十分的相似,似乎只要將最後的找出現次數最大的模式串改為輸出所有模式串的出現次數就行了 反正當時我是這樣想的,然後略微修改程式碼後交上發現。

果然,二次加強版就是不一樣……

重新讀題,意外發現最後一句話:資料不保證任意兩個模式串不相同

???不保證,讀錯題了!(不要犯這樣的低階錯誤),這裡還是比較簡單的,只需要判一下重就好了,直接上程式碼,相信看到這裡的聰明的你一定可以看懂它!修改的主要位置加上註釋了。

/*
  The code was written by yifan, and yifan is neutral!!!
 */

#include <bits/stdc++.h>
using namespace std;
typedef long long ll;

template<typename T>
inline T read() {
	T x = 0;
	bool fg = 0;
	char ch = getchar();
	while (ch < '0' || ch > '9') {
		fg |= (ch == '-');
		ch = getchar();
	}
	while (ch >= '0' && ch <= '9') {
		x = (x << 3) + (x << 1) + (ch ^ 48);
		ch = getchar();
	}
	return fg ? ~x + 1 : x;
}

const int N = 2e5 + 5;
const int M = 2e6 + 5;

int n, tot;
int ans[N], mp[N];
string st[N];
char s[M];
queue<int> q;

struct node {
	int End, fail;
	int tr[26];
} ac[N];

void Insert(string s, int num) {
	int l = s.length(), u = 0;
	for (int i = 0; i < l; ++ i) {
		if (!ac[u].tr[s[i] - 'a']) {
			ac[u].tr[s[i] - 'a'] = ++ tot;
		}
		u = ac[u].tr[s[i] - 'a'];
	}
	if (!ac[u].End) {// 修改點 1
		ac[u].End = num;
	}
	mp[num] = ac[u].End;
}

void get_fail() {
	for (int i = 0; i < 26; ++ i) {
		if (ac[0].tr[i]) {
			ac[ac[0].tr[i]].fail = 0;
			q.emplace(ac[0].tr[i]);
		}
	}
	while (!q.empty()) {
		int u = q.front();
		q.pop();
		for (int i = 0; i < 26; ++ i) {
			if (ac[u].tr[i]) {
				ac[ac[u].tr[i]].fail = ac[ac[u].fail].tr[i];
				q.emplace(ac[u].tr[i]);
			} else {
				ac[u].tr[i] = ac[ac[u].fail].tr[i];
			}
		}
	}
}

void ask(char* s) {
	int l = strlen(s), u = 0;
	for (int i = 0; i < l; ++ i) {
		u = ac[u].tr[s[i] - 'a'];
		for (int cur = u; cur; cur = ac[cur].fail) {
			++ ans[ac[cur].End];
		}
	}
}

int main() {
	n = read<int>();
	for (int i = 1; i <= n; ++ i) {
		cin >> st[i];
		Insert(st[i], i);
	}
	get_fail();
	scanf("%s", s + 1);
	ask(s + 1);
	for (int i = 1; i <= n; ++ i) {
		printf("%d\n", ans[mp[i]]); // 修改點 2
	}
	return 0;
}

再次提交,得到了這樣的結果。

沒辦法,去 \(\texttt{OI-Wiki}\) 上看了看,發現原來有優化,優化的方式使用 拓撲排序

不會拓撲排序的朋友先去學習一下拓撲排序吧。拓撲排序 - OI Wiki (oi-wiki.org)

我們為什麼會 T 呢?

看這段程式碼

void ask(char* s) {
	int l = strlen(s), u = 0;
	for (int i = 0; i < l; ++ i) {
		u = ac[u].tr[s[i] - 'a'];
		for (int cur = u; cur; cur = ac[cur].fail) {
			++ ans[ac[cur].End];
		}
	}
}

我們沿著 fail 指標一步一步地跳,對於下面的圖。

我們假設:

先搜到 \(14\) 號節點,答案更新;然後搜到了 \(13\) 號節點,答案更新,再找到 \(14\) 號節點,答案更新;之後搜到了 \(11\) 號節點,順著 fail 答案更新;再之後搜到了 \(8\) 號節點,順著 fail 答案更新。

你會發現,效率慢的很!然後就被這道題卡了。

如何提高效率的,我們可以在 \(8、11、13、14\) 號節點上各打上標記,然後從 \(8\) 號開始,標記順著 fail 傳遞過去,最後統計的答案為:\(8\) 號統計了 \(1\) 次,\(11\) 號統計了 \(2\) 次,\(13\) 號統計了 \(3\) 次,\(14\) 號統計了 \(4\) 次,這樣統計的答案與一次又一次地更新是一樣的,但是這種方法效率高了很多。

具體怎麼實現呢,就用拓撲排序,把 fail 指標作為邊,最後 fail 指標一定不會成環,所以可以跑拓撲排序,修改一下程式碼就可以了。

/*
  The code was written by yifan, and yifan is neutral!!!
 */

#include <bits/stdc++.h>
using namespace std;
typedef long long ll;

template<typename T>
inline T read() {
	T x = 0;
	bool fg = 0;
	char ch = getchar();
	while (ch < '0' || ch > '9') {
		fg |= (ch == '-');
		ch = getchar();
	}
	while (ch >= '0' && ch <= '9') {
		x = (x << 3) + (x << 1) + (ch ^ 48);
		ch = getchar();
	}
	return fg ? ~x + 1 : x;
}

const int N = 2e5 + 5;
const int M = 2e6 + 5;

int n, tot;
int ans[N], mp[N], in[N];
string st[N];
char s[M];
queue<int> q;

struct node {
	int End, fail, tag;
	int tr[26];
} ac[N];

void Insert(string s, int num) {
	int l = s.length(), u = 0;
	for (int i = 0; i < l; ++ i) {
		if (!ac[u].tr[s[i] - 'a']) {
			ac[u].tr[s[i] - 'a'] = ++ tot;
		}
		u = ac[u].tr[s[i] - 'a'];
	}
	if (!ac[u].End) {
		ac[u].End = num;
	}
	mp[num] = ac[u].End;
}

void get_fail() {
	for (int i = 0; i < 26; ++ i) {
		if (ac[0].tr[i]) {
			ac[ac[0].tr[i]].fail = 0;
			q.emplace(ac[0].tr[i]);
		}
	}
	while (!q.empty()) {
		int u = q.front();
		q.pop();
		for (int i = 0; i < 26; ++ i) {
			if (ac[u].tr[i]) {
				ac[ac[u].tr[i]].fail = ac[ac[u].fail].tr[i];
				q.emplace(ac[u].tr[i]);
				++ in[ac[ac[u].fail].tr[i]];
			} else {
				ac[u].tr[i] = ac[ac[u].fail].tr[i];
			}
		}
	}
}

void ask(char* s) {
	int l = strlen(s), u = 0;
	for (int i = 0; i < l; ++ i) {
		u = ac[u].tr[s[i] - 'a'];
		++ ac[u].tag; // 修改部分 1
	}
}

void topsort() { // 修改部分 2
	for (int i = 1; i <= tot; ++ i) {
		if (!in[i]) {
			q.emplace(i);
		}
	}
	while (!q.empty()) {
		int fr = q.front();
		q.pop();
		ans[ac[fr].End] = ac[fr].tag;
		int u = ac[fr].fail;
		ac[u].tag += ac[fr].tag;
		if (! (-- in[u])) {
			q.emplace(u);
		}
	}
}

int main() {
	n = read<int>();
	for (int i = 1; i <= n; ++ i) {
		cin >> st[i];
		Insert(st[i], i);
	}
	get_fail();
	scanf("%s", s + 1);
	ask(s + 1);
	topsort();
	for (int i = 1; i <= n; ++ i) {
		printf("%d\n", ans[mp[i]]);
	}
	return 0;
}

然後,我們就得到了想要的 AC!

完結!