論文原文:Auto-Encoding Variational Bayes [OpenReview (ICLR 2014) | arXiv]
本文記錄了我在學習 VAE 過程中的一些公式推導和思考。如果你希望從頭開始學習 VAE,建議先看一下蘇劍林的部落格(本文末尾有連結)。
VAE 的整體框架
VAE 認為,隨機變數 \(\boldsymbol{x} \sim p(\boldsymbol{x})\) 由兩個隨機過程得到:
- 根據先驗分佈 \(p(\boldsymbol{z})\) 生成隱變數 \(\boldsymbol{z}\)。
- 根據條件分佈 \(p(\boldsymbol{x} | \boldsymbol{z})\) 由 \(\boldsymbol{z}\) 得到 \(\boldsymbol{x}\)。
於是 \(p(\boldsymbol{x}, \boldsymbol{z}) = p(\boldsymbol{z})p(\boldsymbol{x} | \boldsymbol{z})\) 就是我們所需要的生成模型。
一種樸素的想法是:先用亂數生成器生成隱變數 \(\boldsymbol{z}\),然後用 \(p(\boldsymbol{x} | \boldsymbol{z})\) 從 \(\boldsymbol{z}\) 中生成出(或者說重構出) \(\boldsymbol{x}\),通過最小化重構損失來訓練模型。這個想法的問題在於:我們無法找到生成的樣本與原始樣本之間的對應關係,重構損失算不了,無法訓練。
VAE 的做法是引入後驗分佈 \(p(\boldsymbol{z} | \boldsymbol{x})\),訓練過程變為:
- 取樣一批原始樣本 \(\boldsymbol{x}\)。
- 用 \(p(\boldsymbol{z} | \boldsymbol{x})\) 獲得每個樣本 \(\boldsymbol{x}\) 對應的隱變數 \(\boldsymbol{z}\)。
- 用 \(p(\boldsymbol{x} | \boldsymbol{z})\) 從隱變數 \(\boldsymbol{z}\) 中重構出 \(\boldsymbol{x}\),通過最小化重構損失來訓練模型。
從這個角度來看,\(p(\boldsymbol{z} | \boldsymbol{x})\) 相當於編碼器,\(p(\boldsymbol{x} | \boldsymbol{z})\) 相當於解碼器,訓練結束後只需要保留解碼器 \(p(\boldsymbol{x} | \boldsymbol{z})\) 即可。
除了重構損失以外,VAE 還有一項 KL 散度損失,希望近似的後驗分佈 \(q(\boldsymbol{z} | \boldsymbol{x})\) 儘量接近先驗分佈 \(p(\boldsymbol{z})\),即最小化二者的 KL 散度。
變分下界的推導
現有 \(N\) 個由分佈 \(P(\boldsymbol{x}; \boldsymbol{\theta})\) 生成的樣本 \(\boldsymbol{x}^{(1)}, \ldots, \boldsymbol{x}^{(N)}\),我們可以使用極大似然估計從這些樣本中估計出分佈的引數 \(\boldsymbol{\theta}\),即
\[\begin{aligned}
\boldsymbol{\theta}
& = \operatorname*{argmax}_{\boldsymbol{\theta}} p(\boldsymbol{x}^{(1)}; \boldsymbol{\theta}) \cdots p(\boldsymbol{x}^{(N)}; \boldsymbol{\theta}) \\
& = \operatorname*{argmax}_{\boldsymbol{\theta}} \ln(p(\boldsymbol{x}^{(1)}; \boldsymbol{\theta}) \cdots p(\boldsymbol{x}^{(N)}; \boldsymbol{\theta})) \\
& = \operatorname*{argmax}_{\boldsymbol{\theta}} \sum_{i=1}^n \ln p(\boldsymbol{x}^{(i)}; \boldsymbol{\theta}).
\end{aligned}
\]
後驗分佈 \(p(\boldsymbol{z} | \boldsymbol{x}) = \frac{p(\boldsymbol{z})p(\boldsymbol{x} | \boldsymbol{z})}{p(\boldsymbol{x})} = \frac{p(\boldsymbol{z})p(\boldsymbol{x} | \boldsymbol{z})}{\int_{\boldsymbol{z}} p(\boldsymbol{x}, \boldsymbol{z}) \mathrm{d}\boldsymbol{z}}\) 是 intractable 的,因為分母處的邊緣分佈 \(p(\boldsymbol{x})\) 積不出來。具體來說,聯合分佈 \(p(\boldsymbol{x}, \boldsymbol{z}) = p(\boldsymbol{z})p(\boldsymbol{x} | \boldsymbol{z})\) 的表示式非常複雜,\(\int_{\boldsymbol{z}} p(\boldsymbol{x}, \boldsymbol{z}) \mathrm{d}\boldsymbol{z}\) 這個積分找不到解析解。
需要使用變分推斷解決後驗分佈無法計算的問題。我們使用一個形式已知的分佈 \(q(\boldsymbol{z}|\boldsymbol{x}^{(i)}; \boldsymbol{\phi})\) 來近似後驗分佈 \(p(\boldsymbol{z}|\boldsymbol{x}^{(i)}; \boldsymbol{\theta})\),於是有
\[\begin{aligned}
\log p(\boldsymbol{x}^{(i)})
& = \int_{\boldsymbol{z}} q(\boldsymbol{z}|\boldsymbol{x}^{(i)})[\log q(\boldsymbol{z}|\boldsymbol{x}^{(i)}) - \log p(\boldsymbol{z}|\boldsymbol{x}^{(i)})] \mathrm{d}\boldsymbol{z} + \int_{\boldsymbol{z}} q(\boldsymbol{z}|\boldsymbol{x}^{(i)})[-\log q(\boldsymbol{z}|\boldsymbol{x}^{(i)}) + \log p(\boldsymbol{z}|\boldsymbol{x}^{(i)})] \mathrm{d}\boldsymbol{z} + \log p(\boldsymbol{x}^{(i)}) \cdot 1 \\
& = \int_{\boldsymbol{z}} q(\boldsymbol{z}|\boldsymbol{x}^{(i)})\log\frac{q(\boldsymbol{z}|\boldsymbol{x}^{(i)})}{p(\boldsymbol{z}|\boldsymbol{x}^{(i)})} \mathrm{d}\boldsymbol{z} + \int_{\boldsymbol{z}} q(\boldsymbol{z}|\boldsymbol{x}^{(i)})[-\log q(\boldsymbol{z}|\boldsymbol{x}^{(i)}) + \log p(\boldsymbol{z}|\boldsymbol{x}^{(i)})] \mathrm{d}\boldsymbol{z} + \log p(\boldsymbol{x}^{(i)}) \cdot \int_{\boldsymbol{z}} q(\boldsymbol{z}|\boldsymbol{x}^{(i)})\mathrm{d}\boldsymbol{z} \\
& = \mathrm{KL}[q(\boldsymbol{z}|\boldsymbol{x}^{(i)}), p(\boldsymbol{z}|\boldsymbol{x}^{(i)})] + \int_{\boldsymbol{z}} q(\boldsymbol{z}|\boldsymbol{x}^{(i)})[-\log q(\boldsymbol{z}|\boldsymbol{x}^{(i)}) + \log p(\boldsymbol{z}|\boldsymbol{x}^{(i)})] \mathrm{d}\boldsymbol{z} + \int_{\boldsymbol{z}} q(\boldsymbol{z}|\boldsymbol{x}^{(i)})\log p(\boldsymbol{x}^{(i)}) \mathrm{d}\boldsymbol{z} \\
& = \mathrm{KL}[q(\boldsymbol{z}|\boldsymbol{x}^{(i)}), p(\boldsymbol{z}|\boldsymbol{x}^{(i)})] + \int_{\boldsymbol{z}} q(\boldsymbol{z}|\boldsymbol{x}^{(i)})[-\log q(\boldsymbol{z}|\boldsymbol{x}^{(i)}) + \log p(\boldsymbol{z}|\boldsymbol{x}^{(i)}) + \log p(\boldsymbol{x}^{(i)})] \mathrm{d}\boldsymbol{z} \\
& = \mathrm{KL}[q(\boldsymbol{z}|\boldsymbol{x}^{(i)}), p(\boldsymbol{z}|\boldsymbol{x}^{(i)})] + \int_{\boldsymbol{z}} q(\boldsymbol{z}|\boldsymbol{x}^{(i)})[-\log q(\boldsymbol{z}|\boldsymbol{x}^{(i)}) + \log (p(\boldsymbol{z}|\boldsymbol{x}^{(i)})p(\boldsymbol{x}^{(i)}))] \mathrm{d}\boldsymbol{z} \\
& = \mathrm{KL}[q(\boldsymbol{z}|\boldsymbol{x}^{(i)}), p(\boldsymbol{z}|\boldsymbol{x}^{(i)})] + \int_{\boldsymbol{z}} q(\boldsymbol{z}|\boldsymbol{x}^{(i)})[-\log q(\boldsymbol{z}|\boldsymbol{x}^{(i)}) + \log p(\boldsymbol{x}^{(i)}, \boldsymbol{z})] \mathrm{d}\boldsymbol{z} \\
& = \mathrm{KL}[q(\boldsymbol{z}|\boldsymbol{x}^{(i)}), p(\boldsymbol{z}|\boldsymbol{x}^{(i)})] + \mathbb{E}_{\boldsymbol{z} \sim q(\boldsymbol{z}|\boldsymbol{x}^{(i)})}[-\log q(\boldsymbol{z}|\boldsymbol{x}^{(i)}) + \log p(\boldsymbol{x}^{(i)}, \boldsymbol{z})] \\
& = \mathrm{KL}[q(\boldsymbol{z}|\boldsymbol{x}^{(i)}), p(\boldsymbol{z}|\boldsymbol{x}^{(i)})] + L(\boldsymbol{\theta}, \boldsymbol{\phi}; \boldsymbol{x}^{(i)}) \\
& \geq L(\boldsymbol{\theta}, \boldsymbol{\phi}; \boldsymbol{x}^{(i)}).
\end{aligned}
\]
利用 KL 散度大於等於 0 這一特性,我們得到了對數似然 \(\log p(\boldsymbol{x}^{(i)})\) 的一個下界 \(L(\boldsymbol{\theta}, \boldsymbol{\phi}; \boldsymbol{x}^{(i)})\),於是可以將最大化對數似然改為最大化這個下界。
這個下界可以進一步寫成
\[\begin{aligned}
L(\boldsymbol{\theta}, \boldsymbol{\phi}; \boldsymbol{x}^{(i)})
& = \int_{\boldsymbol{z}} q(\boldsymbol{z}|\boldsymbol{x}^{(i)})[-\log q(\boldsymbol{z}|\boldsymbol{x}^{(i)}) + \log p(\boldsymbol{x}^{(i)}, \boldsymbol{z})] \mathrm{d}\boldsymbol{z} \\
& = \int_{\boldsymbol{z}} q(\boldsymbol{z}|\boldsymbol{x}^{(i)})[-\log q(\boldsymbol{z}|\boldsymbol{x}^{(i)}) + \log (p(\boldsymbol{z})p(\boldsymbol{x}^{(i)}|\boldsymbol{z}))] \mathrm{d}\boldsymbol{z} \\
& = \int_{\boldsymbol{z}} q(\boldsymbol{z}|\boldsymbol{x}^{(i)})[-\log q(\boldsymbol{z}|\boldsymbol{x}^{(i)}) + \log p(\boldsymbol{z}) + \log p(\boldsymbol{x}^{(i)}|\boldsymbol{z})] \mathrm{d}\boldsymbol{z} \\
& = -\int_{\boldsymbol{z}} q(\boldsymbol{z}|\boldsymbol{x}^{(i)})[\log q(\boldsymbol{z}|\boldsymbol{x}^{(i)}) - \log p(\boldsymbol{z})] \mathrm{d}\boldsymbol{z} + \int_{\boldsymbol{z}} q(\boldsymbol{z}|\boldsymbol{x}^{(i)})\log p(\boldsymbol{x}^{(i)}|\boldsymbol{z})] \mathrm{d}\boldsymbol{z} \\
& = -\mathrm{KL}[q(\boldsymbol{z}|\boldsymbol{x}^{(i)}), p(\boldsymbol{z})] + \mathbb{E}_{\boldsymbol{z} \sim q(\boldsymbol{z}|\boldsymbol{x}^{(i)})}[\log p(\boldsymbol{x}^{(i)}|\boldsymbol{z})]. \\
\end{aligned}
\]
其中的第一項是 KL 散度損失,第二項是重構損失。
KL 散度損失
使用標準正態分佈作為先驗分佈,即 \(p(\boldsymbol{z}) = N(\boldsymbol{z}; \boldsymbol{0}, \boldsymbol{I})\)。
使用一個由 MLP 的輸出來引數化的正態分佈作為近似後驗分佈,即 \(q(\boldsymbol{z}|\boldsymbol{x}^{(i)}; \boldsymbol{\phi}) = N(\boldsymbol{z}; \boldsymbol{\mu}(\boldsymbol{x}^{(i)}; \boldsymbol{\phi}), \boldsymbol{\sigma}^2(\boldsymbol{x}^{(i)}; \boldsymbol{\phi})\boldsymbol{I})\)。
選擇正態分佈的好處在於 KL 散度的這個積分可以寫出解析解,訓練時直接按照公式計算即可,無需通過取樣的方式來算積分。
由於我們選擇的是各分量獨立的多元正態分佈,因此只需要推導一元正態分佈的情形即可:
\[\begin{aligned}
\mathrm{KL}[N(z; \mu, \sigma^2), N(z; 0, 1)]
& = \int_z N(z; \mu, \sigma^2)\log\frac{N(z; \mu, \sigma^2)}{N(z; 0, 1)} \mathrm{d}z \\
& = \int_z N(z; \mu, \sigma^2) \log\frac{\frac{1}{\sqrt{2\pi}\sigma}\exp\left(-\frac{(z - \mu)^2}{2\sigma^2}\right)}{\frac{1}{\sqrt{2\pi}}\exp\left(-\frac{z^2}{2}\right)} \mathrm{d}z \\
& = \int_z N(z; \mu, \sigma^2) \log\left(\frac{1}{\sqrt{\sigma^2}}\exp\left(\frac{1}{2}\left(-\frac{(z - \mu^2)^2}{\sigma^2} + z^2\right)\right)\right) \mathrm{d}z \\
& = \frac{1}{2}\int_z N(z; \mu, \sigma^2) \left(-\log\sigma^2 - \frac{(z - \mu)^2}{\sigma^2} + z^2\right)\mathrm{d}z \\
& = \frac{1}{2}\left(-\log\sigma^2\int_z N(z; \mu, \sigma^2) \mathrm{d}z - \frac{1}{\sigma^2}\int_z N(z; \mu, \sigma^2)(z - \mu)^2\mathrm{d}z + \int_z N(z; \mu, \sigma^2)z^2\mathrm{d}z\right) \\
& = \frac{1}{2}\left(-\log\sigma^2 \cdot 1 - \frac{1}{\sigma^2} \cdot \sigma^2 + \mu^2 + \sigma^2\right) \\
& = \frac{1}{2}(-\log\sigma^2 - 1 + \mu^2 + \sigma^2).
\end{aligned}
\]
解釋一下倒數第三行的三個積分:
- \(\int_z N(z; \mu, \sigma^2) \mathrm{d}z\) 是概率密度函數的積分,也就是 1。
- \(\int_z N(z; \mu, \sigma^2)(z - \mu)^2\mathrm{d}z\) 是方差的定義,也就是 \(\sigma^2\)。
- \(\int_z N(z; \mu, \sigma^2)z^2\mathrm{d}z\) 是正態分佈的二階矩,結果為 \(\mu^2 + \sigma^2\)。
重構損失
伯努利分佈模型
當 \(\boldsymbol{x}\) 是二值向量時,可以用伯努利分佈(兩點分佈)來建模 \(p(\boldsymbol{x}|\boldsymbol{z})\),即認為向量 \(\boldsymbol{x}\) 的每個維度都服從對應的相互獨立的伯努利分佈。使用一個 MLP 來計算各維度所對應的伯努利分佈的引數,第 \(i\) 維伯努利分佈的引數為 \(y_i = \boldsymbol{y}(\boldsymbol{z})_i\),於是有
\[p(\boldsymbol{x}|\boldsymbol{z}) = \prod_{i=1}^D y_i^{x_i}(1 - y_i)^{1 - x_i},
\]
\[\log p(\boldsymbol{x}|\boldsymbol{z}) = \sum_{i=1}^D x_i\log y_i + (1 - x_i)\log(1 - y_i).
\]
其中 \(D\) 表示向量 \(\boldsymbol{x}\) 的維度。可見此時最大化 \(\log p(\boldsymbol{x}|\boldsymbol{z})\) 等價於最小化交叉熵損失。
正態分佈模型
當 \(\boldsymbol{x}\) 是實值向量時,可以用正態分佈來建模 \(p(\boldsymbol{x}|\boldsymbol{z})\)。使用一個 MLP 來計算正態分佈的引數,於是有
\[\begin{aligned}
p(\boldsymbol{x}|\boldsymbol{z})
& = N(\boldsymbol{x}; \boldsymbol{\mu}, \boldsymbol{\sigma}^2\boldsymbol{I}) \\
& = \prod_{i=1}^D N(x_i; \mu_i, \sigma_i^2) \\
& = \left(\prod_{i=1}^D\frac{1}{\sqrt{2\pi}\sigma_i}\right)\exp\left(\sum_{i=1}^D-\frac{(x_i - \mu_i)^2}{2\sigma_i^2}\right),
\end{aligned}
\]
\[\log p(\boldsymbol{x}|\boldsymbol{z}) = -\frac{D}{2}\log 2\pi - \frac{1}{2}\sum_{i=1}^D\log\sigma_i^2 - \frac{1}{2}\sum_{i=1}^D\frac{(x_i - \mu_i)^2}{\sigma_i^2}.
\]
很多時候我們會假設 \(\sigma_i^2\) 是一個常數,於是 MLP 只需要輸出均值引數 \(\boldsymbol{\mu}\) 即可。此時有
\[\log p(\boldsymbol{x}|\boldsymbol{z}) \sim -\frac{1}{2}\sum_{i=1}^D(x_i - \mu_i)^2 = -\frac{1}{2}\|\boldsymbol{x} - \boldsymbol{\mu}(\boldsymbol{z})\|^2.
\]
可見此時最大化 \(\log p(\boldsymbol{x}|\boldsymbol{z})\) 等價於最小化 MSE 損失。
重引數化技巧
需要使用重引數化技巧解決取樣 \(z\) 時不可導的問題。解決的思路是先從無引數分佈中取樣一個 \(\varepsilon\),再通過變換得到 \(z\)。
從 \(N(\mu, \sigma^2)\) 中取樣一個 \(z\),相當於先從 \(N(0, 1)\) 中取樣一個 \(\varepsilon\),然後令 \(z = \mu + \varepsilon\cdot\sigma\)。
相關知識
技巧,通過取對數把乘除變成加減:
\[\ln ab = \ln a + \ln b,\ \ln\frac{a}{b} = \ln a - \ln b.
\]
隨機變數的函數的期望:
\[\mathbb{E}_{x \sim P(x)} g(x) = \int_x p(x)g(x) \mathrm{d}x,
\]
利用此公式可以將積分改寫成期望的形式,這樣就可以用取樣的方式計算積分了(蒙特卡羅積分法)。
條件概率密度的定義:
\[p_{Y|X}(y|x) = \frac{p(x, y)}{p_X(x)},
\]
此處的 \(p\) 並不是概率而是概率密度函數,但是這個公式在形式上跟條件概率公式是一樣的。
參考資料
蘇劍林的 VAE 系列部落格:
15 分鐘瞭解變分推理: