演演算法學習筆記(20): AC自動機

2023-03-26 15:00:34

AC自動機

前置知識

使用場景

AC自動機是一種著名的多模式匹配演演算法。

可以完成類似於KMP演演算法的工作,但是由單字串的匹配變成了多字串的匹配。

一般來說,會有很多子串,和一個母串。問題常是求字串在母串中的出現情況(包括位置,次數,等等)

演演算法思想與流程

我在Trie樹一文中提到過這樣一句話

而AC自動機的核心就在於通過對Trie樹進行處理,使得在處理母串的資訊時可以快速的進行狀態轉移。

可以類比KMP的演演算法流程,但是這不重要

例如子串有 aa, ab, abc, b。母串為 ababcba

由於我們是通過母串進行狀態轉移,所以需要先把所有字串的資訊搞定

我們可以先處理子串,建一棵Trie樹

明顯,對於一個字串的匹配,是不可能在樹上一路到底的,所以要構建匹配失敗時的回退機制。也就是需要構建失配指標。

那麼失配指標是幹什麼的?也就是用來在 Trie 樹上向上跳,找到可以轉移的一個節點,進行狀態轉移。

假如我現在在3號節點,並且我下一個需要轉移的狀態是 b,很明顯,我此時應該回退到1節點(其上第一個可以通過 b 轉移的節點)並轉移到4節點。如果再來一個 b,也只能向上走到0號節點,然後轉移到2號節點。

如此看來,我們完全可以暴力向上跳找到可轉移的狀態或者到達根為止。但是,這明顯不夠優秀,我們完全可以繼承其子節點的。也就是繼承 fail 的子節點。使得不需要暴力向上跳。

那說了半天,fail 到底指向啥?

假設父節點到當前節點轉移的狀態為 x,父節點之上第一個可以通過 x 轉移到下一個節點的節點為 u,則 fail 指向 u 通過 x 轉移過後的節點。

其實還有另一種解釋的方法

fail 指向 p 代表當前串的最長已知字尾。

例如 aa 的最長已知字尾為 a,所以 3號節點的 fail 指向 1號節點;abc 的最長已知字尾為空,所以 5 號節點的 fail 指向根節點。

好混亂,我盡力了……

那麼核心程式碼……就是利用 BFS 來處理

void procFail(int * q) {
    int head(0), tail(0);
    for (int i(0); i < 26; ++i) {
        if (kids[0][i]) q[tail++] = kids[0][i];
    }

    while (head ^ tail) {
        int x = q[head++];
        for (int i(0); i < 26; ++i) {
            if (kids[x][i]) {
                fail[kids[x][i]] = kids[fail[x]][i];
                q[tail++] = kids[x][i];
            } else kids[x][i] = kids[fail[x]][i];
        }
    } // procFail end
}

注意事項:一般來說,把 0 號作為根節點會比較方便。反正 0 上不可能有資訊儲存。

插入部分我就不需要講了

匹配的判斷

如何判斷當前狀態有沒有匹配任何一個字串,只需要不斷向上跳 fail,看跳到的節點是不是代表著字串。

拿模板:【模板】AC 自動機(簡單版) - 洛谷 為例。

插入的時候在最後標記一下有沒有匹配:

void insert(string &s) {
    int p(0);
    for (int c : s) {
        if (!kids[p][(c -= 'a')]) kids[p][c] = ++usage;
        p = kids[p][c];
    }
    ++cnt[p];
}

在匹配的時候暴力跳就是了:

int ACMatch(string & s) {
    int p(0), ans(0);
    for (int c : s) {
        p = kids[p][(c -= 'a')];
        for (int t(p); t && ~cnt[t]; t = fail[t]) {
            ans += cnt[t], cnt[t] = -1;
        }
    }
    return ans;
}

由於每一個串只能匹配一次,所以這裡採用的清空的策略。並且標記清空,以免重複搜尋。

失配樹的應用

就拿模板題來說吧:【模板】AC 自動機(二次加強版) - 洛谷

他是要求所有字串的出現情況。

那麼,我們先把每一個到達的狀態計數。再通過 fail 指標向上跳求和。

但畢竟不能每一個節點都暴力跳,所以考慮在 fail 樹上求和。

但是,我們不是有一個 qBFS 嗎?其中的 fail 是有序的:對於一個節點 x,其 fail 一定在 x 之前被遍歷到。

所以我們直接使用 q 即可。

那麼合起來大概也就是這樣:

inline void ACMatch(string &s) {
    int p(0);
    for (char c : s) {
        p = kids[p][c - 'a'];
        ++cnt[p];
    }
}

inline void ACCount(int * q) {
    for (int i = usage; i; --i) {
        cnt[fail[q[i]]] += cnt[q[i]];
    }
}

但是每一個特定的字串出現的次數呢?

在插入時記住字串對應的節點,輸出即可。

void insert(string &s, int i) {
    int p(0);
    for (int c : s) {
        if (!kids[p][(c -= 'a')]) kids[p][c] = newNode();
        p = kids[p][c];
    }
    pos[i] = p;
}


inline void ACOutput(int n) {
    for (int i = 1; i <= n; ++i) {
        cout << cnt[pos[i]] << '\n';
    }
}

有這麼一道題:

很明顯,對於每一個位置,我們需要清理能匹配到的最長長度,所以我們需要預處理出最長長度:

inline void ACprepare(int * q) {
    for (int i = 1; i <= usage; ++i) {
        len[q[i]] = max(len[q[i]], len[fail[q[i]]]);
    }
}

在清理時:

inline void ACclean(string &s) {
    int p(0);
    for (unsigned i(0), ie = s.size(); i < ie; ++i) {
        p = kids[p][discrete(s[i])];
        if (len[p]) for (unsigned j = i - len[p] + 1; j <= i; ++j)
            s[j] = '*';
    }
}

由於是參照的字串,所以可以直接修改。

對狀態的理解

在我們考試的時候有這麼一道題:

這道題說難也難,說不難也不難。主要是看對於 AC自動機 狀態轉移的理解到不到位。

在匹配過程中,如果匹配到了出現的 w,那麼就要回到 len(w) 個狀態前,繼續匹配下一個字元。

很明顯,需要用棧,並且由於需要一次彈出多個,所以最好用手寫的棧。

核心程式碼如下:

string sub, pat;
cin >> sub >> pat;
insert(sub), procFail(Q);

int p = 0;
for (int i(0), ie = pat.size(); i < ie; ++i) {
    p = kids[cps[ci]][pat[i] - 'a'];
    cps[++ci] = p, ccs[ci] = pat[i];
    if (match[p]) ci -= sub.size();
}

for (int i = 1; i <= ci; ++i) {
    putchar(ccs[i]);
}

這裡沒有用到 fail,那麼為什麼還要構建失配樹?

這是個好問題,因為,構建失配樹的過程不僅僅構建了失配樹,同時還令節點繼承了其 fail 的子節點,所以需要構建的過程。


最後附上模板題【模板】AC 自動機(二次加強版) - 洛谷的程式碼:

#include <iostream>
#include <algorithm>
#include <string>

using namespace std;
const int N = 1e6 + 7;

int res[N], cnt[N], pos[N];
class ACAutomaton {
private:
	int kids[N][26];
	int fail[N], id[N], usage;
public:
	ACAutomaton() : usage(0) {
	}
	
	inline int newNode() {
		fill_n(kids[++usage], 26, 0);
		cnt[usage] = fail[usage] = id[usage] = 0;
		return usage;
	}
	
	void insert(string &s, int i) {
		int p(0);
		for (int c : s) {
			if (!kids[p][(c -= 'a')]) kids[p][c] = newNode();
			p = kids[p][c];
		}
		pos[i] = p;
	}
	
	void procFail(int * q) {
		int head(0), tail(0);
		for (int i(0); i < 26; ++i) {
			if (kids[0][i])
				fail[kids[0][i]] = 0, q[tail++] = kids[0][i];
		}
		
		while (head ^ tail) {
			int x = q[head++];
			for (int i(0); i < 26; ++i) {
				if (kids[x][i]) {
					fail[kids[x][i]] = kids[fail[x]][i];
					q[tail++] = kids[x][i];
				} else kids[x][i] = kids[fail[x]][i];
			}
		} // procFail end
	}
	
	void debug() {
		for (int i = 0; i <= usage; ++i) {
			printf("node %d (cnt %d) fail to %d:\n\t", i, cnt[i], fail[i]);
			for (int j(0); j < 26; ++j) {
				printf("%d ", kids[i][j]);
			} puts("");
		}
	}
	
	inline void ACMatch(string &s) {
		int p(0);
		for (char c : s) {
			p = kids[p][c - 'a'];
			++cnt[p];
		}
	}
	
	inline void ACCount(int * q) {
		for (int i = usage; i; --i) {
			cnt[fail[q[i]]] += cnt[q[i]];
		}
	}
	
	inline void ACOutput(int n) {
		for (int i = 1; i <= n; ++i) {
			cout << cnt[pos[i]] << '\n';
		}
	}
	
	void clear() {
		usage = -1;
		newNode(); // clear 0
	}
} ac;

int Q[N];
string s; 

int main() {
	cin.tie(0)->sync_with_stdio(false);
	
	int n;
	cin >> n;
	for (int i = 1; i <= n; ++i) {
		cin >> s;
		ac.insert(s, i);
	} ac.procFail(Q);
	
	cin >> s;
	ac.ACMatch(s);
	ac.ACCount(Q);
	ac.ACOutput(n);
	return 0;
}

差不多了……下課