FWT/快速沃爾什變換 入門指南

2023-03-17 18:01:24

來學點好玩的。


引入

我們也許學過,\(FFT\) 可以解決一類折積:

\[C_i=\sum^{k+j=i} A_iB_j \]

現在我們稍微變一下式子:

\[C_i=\sum^{i=k \And j} A_kB_j \]

\[C_i=\sum^{i =k\mid j} A_kB_j \]

\[C_i=\sum^{i=k \oplus j} A_kB_j \]

上面那個圓圓的東西是互斥或

怎麼求?

\(FWT\)

也就是這個式子:

\[C_i=\sum^{i =k\mid j} A_kB_j \]

\(FWT\) 的定義

模仿 \(FFT\),對於 \(A\) 我們想要在可接受時間內得到一個 \(FWT(A)\),使得

\[FWT(C)_i=FWT(A)_i\times FWT(B)_i \]

\(i\) 代表第 \(i\) 個位置的數。

這樣我們就可以 \(O(n)\) 暴力乘起來了。

因此我們構造 \(FWT(A)_i=\sum^{}_{j|i=i}A_j\)

現在我們來證明這個構造的正確性。

\[FWT(A)_i\times FWT(B)_i \]

\[=\sum^{}_{j|i=i}A_j\times \sum^{}_{k|i=i}B_k \]

\[=\sum^{}_{j|i=i}\sum^{}_{k|i=i}A_j B_k \]

因為可以從 \(j|i=i\)\(k|i=i\) 中得出 \((j|k)|i=i\),所以我們還可以消去另一個式子

\[=\sum^{}_{(j|k)|i=i}A_j B_k \]

根據 \(FWT\) 的定義,這個式子就是 \(FWT(C)_i\)。證畢。

另一種理解

涉及到了或運算,我們不妨把數全都變成二進位制。於是我們想到一種經典轉換:將二進位制中的 \(0\)\(1\) 轉化為集合中一個數選還是不選。那麼或操作代表什麼呢?

bingo!兩個集合的並集!

於是我們可以把上面的式子改寫一個形式:

\[C_i=\sum^{i=k \mid j} A_kB_j \]

變成

\[C_i=\sum^{i =k\bigcup j} A_kB_j \]

注意看,這時 \(i,j,k\) 都是集合,只不過我們將其用二進位制表示.

這樣我們可以改寫 \(FWT\) 的定義。

重新來一遍,\(FWT(A)_i=\sum^{}_{k \subseteq i} A_k\)

因此我們為 \(FWT\) 找到了新的定義,他代表著集合 \(i\) 的子集之和。我們有了更自然的推導:

\[\sum^{}_{j \subseteq i}A_j\times \sum^{}_{k \subseteq i}B_k \]

\[=\sum^{}_{j,k \subseteq i}A_j\times B_k \]

\[=\sum^{}_{x \subseteq i}\sum^{}_{j\bigcup k = x }A_j\times B_k \]

\[=\sum^{}_{x \subseteq i}C_x \]

\[=FWT(c)_x \]

換句話說,我們對子集做了個字首和操作(發現了嗎?子集的字首和進行或折積與普通字首和進行加法折積具有相似性),並用字首和相乘代替了原來的 \(O(n^2)\) 相乘。

如何變化

我們把原序列 \(A\) 按下標最高位是 \(0\) 還是 \(1\) 分成兩部分 \(A_0\)\(A_1\) 分治求解。顯然,前半部分(最高位為 \(0\) 的部分)就是 \(FWT(A_0)\),所以我們考慮後半部分的答案。

後半部分最高位為 \(1\),因此此時「子集」這一概念不僅包含分治處理的他子集,還包括把最大值變為 \(0\) 後的,序列 \(A0\) 中同一位置的子集。要將 \(A_0\) 中的同一位置加到當前答案上。

寫成數學形式就是:

\[FWT(A) = merge(FWT(A_0),FWT(A_0)+FWT(A_1)) \]

上面的 \(merge\) 代表拼接,就是字面意思。

於是我們就能寫出分治遞迴程式碼了!但為了常數著想,我們試著把遞迴這一步驟去掉。

去掉的部分並不難寫,我們按照層數從小到大遞迴,不難發現第 \(i\) 層(從 \(0\) 開始編號,最底層為 \(0\))就是列舉第 \(i\) 位是 \(0\) 還是 \(1\),並且亂填其他數進行轉移。

程式碼也簡單:

void OR(mint *f){
    for(int o = 2, k = 1; o <= n; o <<= 1, k <<= 1)// k 為當前的層,o 僅用於窮舉左邊進行轉移
        for(int i = 0; i < n; i += o)// 窮舉左邊
            for(int j = 0; j < k; j++){ // 窮舉右邊
                f[i + j + k] = f[i + j + k] + f[i + j];
            }
}

如何轉回來

再看一眼轉移的式子

\[FWT(A) = merge(FWT(A_0),FWT(A_0)+FWT(A_1)) \]

思考只有兩個數的情況。此時 \(1\) 位置是不會變的,\(2\) 位置加上了 \(1\) 位置的貢獻,要減去。

我們發現更大的情況也是一樣的,只要依次把前面的貢獻減去就好。

void IOR(mint *f){
    for(int o = 2, k = 1; o <= n; o <<= 1, k <<= 1)// k 為當前的層,o 僅用於窮舉左邊進行轉移
        for(int i = 0; i < n; i += o)// 窮舉左邊
            for(int j = 0; j < k; j++){ // 窮舉右邊
                f[i + j + k] = f[i + j + k] - f[i + j];
            }
}

這兩份程式碼顯然是可以合併的。因此我們得到了 \(FWT\) 或 的全過程。

void OR(mint *f, int type){
    for(int o = 2, k = 1; o <= n; o <<= 1, k <<= 1)
        for(int i = 0; i < n; i += o)
            for(int j = 0; j < k; j++){
                f[i + j + k] = f[i + j + k] + (f[i + j] * mint(type));
            }
}

\(FWT\)

和 或 差不多,只是要從 \(1\) 轉移到 \(0\)

可以發現,實際上我們用子集字尾和優化了運算。

void AND(mint *f, int type){
    for(int o = 2, k = 1; o <= n; o <<= 1, k <<= 1)
        for(int i = 0; i < n; i += o)
            for(int j = 0; j < k; j++)
                f[i + j] = f[i + j] + (f[i + j + k] * mint(type));
}

\(FWT\) 互斥或

\[C_i=\sum^{i=k \oplus j} A_kB_j \]

很遺憾,我並沒有發現這個東西的集合意義,如果有大佬知道可以告訴我。。。

正著轉化

思考 \(FWT\) 的作用,我們想要把 \(A_kB_j\) 變成 \(A_iB_i\) 的形式,以此來簡化運算。

我們考慮這樣 \(n\)\(n\) 維向量 \(b\)\(b(i)\) 只有下標 \(i\) 處是 \(1\),其他位置都是 \(0\)

現在我們把 \(FWT\) 後的 \(A,B\) 看作係數,此時顯然 \(A_1b(1),A_2b(2),...,A_nb(n)=A_1,A_2,A_3,A_4...,A_n\)

顯然,互斥或折積對於乘法有分配律。

設互斥或折積為 \(\ast\),則

\[(\sum^{n}_{i=1} A_ib(i)) \ast (\sum^{n}_{i=1} B_jb(j)) \]

\[=(\sum^{n}_{i=1}\sum^{n}_{j=1} A_iB_j (b(i)\ast b(j)) \]

發現後面的東西可以簡單表示,即 \(b_i \ast b_i = b_i,b_i \ast b_j = 0(i \neq j)\)

那麼整個式子就是我們尋找的形式:

\[\sum^{n}_{i=1} A_iB_i \]

而我們要做的事情無非是求出 \(FWT\) 之前的 \(b_i\)

太長不看版:互斥或 \(FWT\) 與原序列線性相關

既然這樣,我們設 \(FWT(A)_x=\sum^{n}_{i=1} g(x,i)A_i\)

那麼因為 \(FWT(C)_x=FWT(A)_x\times FWT(b)_x\)

所以 \(\sum^{n}_{k=1} g(x,k)C_k=\sum^{n}_{i=1} g(x,i)A_i \times \sum^{n}_{j=1} g(x,j)B_j\)

整理一下可以得出:

\[\sum^{n}_{k=1} g(x,k)C_k=\sum^{n}_{i=1}\sum^{n}_{j=1} g(x,i)g(x,j)\times A_iB_j \]

\(C_k\)\(A,B\) 表示可得:

\[\sum^{n}_{k=1} g(x,k)\sum^{k=i \oplus j} A_iB_j=\sum^{n}_{i=1}\sum^{n}_{j=1} g(x,i)g(x,j)\times A_iB_j \]

更改求和順序,我們列舉 \(i,j\) 可得:

\[\sum^{n}_{i=1}\sum^{n}_{j=1} g(x,i \oplus j) A_iB_j=\sum^{n}_{i=1}\sum^{n}_{j=1} g(x,i)g(x,j)\times A_iB_j \]

於是我們發現了 \(g\) 的關係:

\[g(x,i \oplus j) = g(x,i)g(x,j) \]

現在問題來了,與 \(i,j\) 相關的什麼東西,使互斥或之後的值等於原來兩值的乘積?

於是我們可以想到有人託夢給我奇偶性。

具體的,我們發現互斥或前後 \(1\) 的個數奇偶性不變。原因如下:

按每一位依次考慮。如果第 \(i\) 位互斥或後為 \(1\),那麼原來必定有且僅有一個 \(1\)。個數不變

如果為 \(0\),要麼是兩個 \(0\),此時 \(1\) 的個數不變,要麼是兩個 \(1\),此時 \(1\) 的個數減 \(2\),奇偶性仍不變。

所以我們定義 \(g(x,i)=(-1)^{|i \bigcap x|}\)。那麼上式就等價於:

\[(-1)^{|(i \oplus j) \bigcap x|} = (-1)^{|i \bigcap x|}(-1)^{|j \bigcap x|} \]

根據上面的推論,左右兩邊奇偶性不變,與 後無非是減去兩個相同的數,奇偶性還是不變。

於是我們得出 \(FWT\) 的轉移式:

\[FWT(A)_x=\sum^{n}_{i=1} (-1)^{|i \bigcap x|}A_i \]

如何求解

考慮模仿前兩個 \(FWT\) 的形式,討論最高位 \(i\)\(0\) 和為 \(1\) 兩種情況。

原來最高位為 \(0\)\(FWT\) 後的前 \(2^{i-1}\) 個數最高位還是 \(0\)。由於 \(1 \And 0=0\),所以後 \(2^{i-1}\) 個數的貢獻為正。前半部分答案為 \(FWT(A_0)+FWT(A_1)\)

\(FWT\) 後的後 \(2^{i-1}\) 個數最高位變成了 \(1\),此時 \(A_0\) 的貢獻還是正(因為 \(1 \And 0=0\))。但是此時後半部分加了 \(1\),於是貢獻要取反。後半部分答案為 \(FWT(A_0)-FWT(A_1)\)

所以我們得出:

\[FWT(A) = merge(FWT(A_0)+FWT(A_1),FWT(A_0)-FWT(A_1)) \]

當然,參照 或 FWT,我們可以寫出不依賴遞迴的程式:

void XOR(mint *f){
    for(int o = 2, k = 1; o <= n; o <<= 1, k <<= 1)//具體意義參考 FWT 或
        for(int i = 0; i < n; i += o)
            for(int j = 0; j < k; j ++){
                mint x = f[i + j], y = f[i + j + k];
                f[i + j] = x + y;
                f[i + j + k] = x - y;
            }
}

求逆變換

實際上就是把貢獻減去

\[IFWT(A) = merge(\frac{IFWT(A_0)+IFWT(A_1)}{2},\frac{IFWT(A_0)-IFWT(A_1))}{2} \]

顯然這兩個東西是可以合併的。於是我們可以得出模板的完整程式碼:

#include<bits/stdc++.h>
using namespace std;

#define forp(i, a, b) for(int i = (a);i <= (b);i ++)
#define forc(i, a, b) for(int i = (a);i >= (b);i --)

const int maxn = 6e5 + 5;
const int mod = 998244353;

int read(){
    int u;cin >> u;return u;
}

class mint{
    private : int v;
    public:
        mint(){}
        int operator()(void)const{
            return v;
        }
        mint (const int &u){ 
            v = u % mod; 
        }
        mint operator+(const mint &a) const{ 
            int x = a.v + v;
            if(x >= mod) return mint(x - mod);
            if(x < 0) return mint(x + mod);
            return x;
        }
        mint operator-(const mint& a)const{
			return v < a.v ? v - a.v + mod : v - a.v;
		}
        mint operator*(const mint &a) const{
            return mint((1ll * a.v * v) % mod);
        }
};

mint qpow(mint u, int v){
    mint ans = mint(1);
    while(v){
        if(v & 1) ans = ans * u;
        u = u * u;
        v >>= 1;
    }
    return ans;
}
mint inv2 = qpow(2, mod - 2);

int n;
mint A[maxn], B[maxn], C[maxn];
mint g[maxn];

void OR(mint *f, int type){
    for(int o = 2, k = 1; o <= n; o <<= 1, k <<= 1)
        for(int i = 0; i < n; i += o)
            for(int j = 0; j < k; j++){
                f[i + j + k] = f[i + j + k] + (f[i + j] * mint(type));
            }
}

void AND(mint *f, int type){
    for(int o = 2, k = 1; o <= n; o <<= 1, k <<= 1)
        for(int i = 0; i < n; i += o)
            for(int j = 0; j < k; j++)
                f[i + j] = f[i + j] + (f[i + j + k] * mint(type));
}

void XOR(mint *f, int type){
    for(int o = 2, k = 1; o <= n; o <<= 1, k <<= 1)
        for(int i = 0; i < n; i += o)
            for(int j = 0; j < k; j ++){
                mint x = f[i + j], y = f[i + j + k];
                f[i + j] = x + y;
                f[i + j + k] = x - y;
                if(type == -1){
                    f[i + j] = f[i + j] * inv2;
                    f[i + j + k] = f[i + j + k] * inv2;
                }
            }
}

signed main(){
	ios::sync_with_stdio(0);
	cin.tie(0);cout.tie(0);

    n = (1 << read());
    forp(i, 0, n - 1) A[i] = mint(read());
    forp(i, 0, n - 1) B[i] = mint(read());

    OR(A, 1);OR(B, 1);
    forp(i, 0, n - 1) C[i] = (A[i] * B[i]);
    OR(C, -1);
    forp(i, 0, n - 1) cout << C[i]() << ' ';
    cout << endl;
    OR(A, -1);OR(B, -1);

    AND(A, 1);AND(B, 1);
    forp(i, 0, n - 1) C[i] = (A[i] * B[i]);
    AND(C, -1);
    forp(i, 0, n - 1) cout << C[i]() << ' ';
    cout << endl;
    AND(A, -1);AND(B, -1);

    XOR(A, 1);XOR(B, 1);
    forp(i, 0, n - 1) C[i] = (A[i] * B[i]);
    XOR(C, -1);
    forp(i, 0, n - 1) cout << C[i]() << ' ';
    cout << endl;
    XOR(A, -1);XOR(B, -1);
    return 0;
}

大概就是這樣。。。應用也許會另外開坑吧。


後記

感謝 xht 的部落格,從這篇部落格里我學到了 FWT 的基礎知識。

感謝同校大佬 yllcm 為本人解釋符號與定義。

感謝萬能的U群群友 Untitled_unrevised 解釋 FWT 的目的。

感謝萬能的U群群友 rqy 學姐與 櫻初音鬥橡皮 解釋為什麼互斥或 FWT 是線性變換順便發現了我線上性代數方面的巨大缺口