變分自編碼器(VAE)公式推導

2023-07-01 15:00:34

論文原文:Auto-Encoding Variational Bayes [OpenReview (ICLR 2014) | arXiv]

本文記錄了我在學習 VAE 過程中的一些公式推導和思考。如果你希望從頭開始學習 VAE,建議先看一下蘇劍林的部落格(本文末尾有連結)。

VAE 的整體框架

VAE 認為,隨機變數 \(\boldsymbol{x} \sim p(\boldsymbol{x})\) 由兩個隨機過程得到:

  1. 根據先驗分佈 \(p(\boldsymbol{z})\) 生成隱變數 \(\boldsymbol{z}\)
  2. 根據條件分佈 \(p(\boldsymbol{x} | \boldsymbol{z})\)\(\boldsymbol{z}\) 得到 \(\boldsymbol{x}\)

於是 \(p(\boldsymbol{x}, \boldsymbol{z}) = p(\boldsymbol{z})p(\boldsymbol{x} | \boldsymbol{z})\) 就是我們所需要的生成模型。

一種樸素的想法是:先用亂數生成器生成隱變數 \(\boldsymbol{z}\),然後用 \(p(\boldsymbol{x} | \boldsymbol{z})\)\(\boldsymbol{z}\) 中生成出(或者說重構出) \(\boldsymbol{x}\),通過最小化重構損失來訓練模型。這個想法的問題在於:我們無法找到生成的樣本與原始樣本之間的對應關係,重構損失算不了,無法訓練。

VAE 的做法是引入後驗分佈 \(p(\boldsymbol{z} | \boldsymbol{x})\),訓練過程變為:

  1. 取樣一批原始樣本 \(\boldsymbol{x}\)
  2. \(p(\boldsymbol{z} | \boldsymbol{x})\) 獲得每個樣本 \(\boldsymbol{x}\) 對應的隱變數 \(\boldsymbol{z}\)
  3. \(p(\boldsymbol{x} | \boldsymbol{z})\) 從隱變數 \(\boldsymbol{z}\) 中重構出 \(\boldsymbol{x}\),通過最小化重構損失來訓練模型。

從這個角度來看,\(p(\boldsymbol{z} | \boldsymbol{x})\) 相當於編碼器\(p(\boldsymbol{x} | \boldsymbol{z})\) 相當於解碼器,訓練結束後只需要保留解碼器 \(p(\boldsymbol{x} | \boldsymbol{z})\) 即可。

除了重構損失以外,VAE 還有一項 KL 散度損失,希望近似的後驗分佈 \(q(\boldsymbol{z} | \boldsymbol{x})\) 儘量接近先驗分佈 \(p(\boldsymbol{z})\),即最小化二者的 KL 散度。

變分下界的推導

現有 \(N\) 個由分佈 \(P(\boldsymbol{x}; \boldsymbol{\theta})\) 生成的樣本 \(\boldsymbol{x}^{(1)}, \ldots, \boldsymbol{x}^{(N)}\),我們可以使用極大似然估計從這些樣本中估計出分佈的引數 \(\boldsymbol{\theta}\),即

\[\begin{aligned} \boldsymbol{\theta} & = \operatorname*{argmax}_{\boldsymbol{\theta}} p(\boldsymbol{x}^{(1)}; \boldsymbol{\theta}) \cdots p(\boldsymbol{x}^{(N)}; \boldsymbol{\theta}) \\ & = \operatorname*{argmax}_{\boldsymbol{\theta}} \ln(p(\boldsymbol{x}^{(1)}; \boldsymbol{\theta}) \cdots p(\boldsymbol{x}^{(N)}; \boldsymbol{\theta})) \\ & = \operatorname*{argmax}_{\boldsymbol{\theta}} \sum_{i=1}^n \ln p(\boldsymbol{x}^{(i)}; \boldsymbol{\theta}). \end{aligned} \]

後驗分佈 \(p(\boldsymbol{z} | \boldsymbol{x}) = \frac{p(\boldsymbol{z})p(\boldsymbol{x} | \boldsymbol{z})}{p(\boldsymbol{x})} = \frac{p(\boldsymbol{z})p(\boldsymbol{x} | \boldsymbol{z})}{\int_{\boldsymbol{z}} p(\boldsymbol{x}, \boldsymbol{z}) \mathrm{d}\boldsymbol{z}}\) 是 intractable 的,因為分母處的邊緣分佈 \(p(\boldsymbol{x})\) 積不出來。具體來說,聯合分佈 \(p(\boldsymbol{x}, \boldsymbol{z}) = p(\boldsymbol{z})p(\boldsymbol{x} | \boldsymbol{z})\) 的表示式非常複雜,\(\int_{\boldsymbol{z}} p(\boldsymbol{x}, \boldsymbol{z}) \mathrm{d}\boldsymbol{z}\) 這個積分找不到解析解。

需要使用變分推斷解決後驗分佈無法計算的問題。我們使用一個形式已知的分佈 \(q(\boldsymbol{z}|\boldsymbol{x}^{(i)}; \boldsymbol{\phi})\)近似後驗分佈 \(p(\boldsymbol{z}|\boldsymbol{x}^{(i)}; \boldsymbol{\theta})\),於是有

\[\begin{aligned} \log p(\boldsymbol{x}^{(i)}) & = \int_{\boldsymbol{z}} q(\boldsymbol{z}|\boldsymbol{x}^{(i)})[\log q(\boldsymbol{z}|\boldsymbol{x}^{(i)}) - \log p(\boldsymbol{z}|\boldsymbol{x}^{(i)})] \mathrm{d}\boldsymbol{z} + \int_{\boldsymbol{z}} q(\boldsymbol{z}|\boldsymbol{x}^{(i)})[-\log q(\boldsymbol{z}|\boldsymbol{x}^{(i)}) + \log p(\boldsymbol{z}|\boldsymbol{x}^{(i)})] \mathrm{d}\boldsymbol{z} + \log p(\boldsymbol{x}^{(i)}) \cdot 1 \\ & = \int_{\boldsymbol{z}} q(\boldsymbol{z}|\boldsymbol{x}^{(i)})\log\frac{q(\boldsymbol{z}|\boldsymbol{x}^{(i)})}{p(\boldsymbol{z}|\boldsymbol{x}^{(i)})} \mathrm{d}\boldsymbol{z} + \int_{\boldsymbol{z}} q(\boldsymbol{z}|\boldsymbol{x}^{(i)})[-\log q(\boldsymbol{z}|\boldsymbol{x}^{(i)}) + \log p(\boldsymbol{z}|\boldsymbol{x}^{(i)})] \mathrm{d}\boldsymbol{z} + \log p(\boldsymbol{x}^{(i)}) \cdot \int_{\boldsymbol{z}} q(\boldsymbol{z}|\boldsymbol{x}^{(i)})\mathrm{d}\boldsymbol{z} \\ & = \mathrm{KL}[q(\boldsymbol{z}|\boldsymbol{x}^{(i)}), p(\boldsymbol{z}|\boldsymbol{x}^{(i)})] + \int_{\boldsymbol{z}} q(\boldsymbol{z}|\boldsymbol{x}^{(i)})[-\log q(\boldsymbol{z}|\boldsymbol{x}^{(i)}) + \log p(\boldsymbol{z}|\boldsymbol{x}^{(i)})] \mathrm{d}\boldsymbol{z} + \int_{\boldsymbol{z}} q(\boldsymbol{z}|\boldsymbol{x}^{(i)})\log p(\boldsymbol{x}^{(i)}) \mathrm{d}\boldsymbol{z} \\ & = \mathrm{KL}[q(\boldsymbol{z}|\boldsymbol{x}^{(i)}), p(\boldsymbol{z}|\boldsymbol{x}^{(i)})] + \int_{\boldsymbol{z}} q(\boldsymbol{z}|\boldsymbol{x}^{(i)})[-\log q(\boldsymbol{z}|\boldsymbol{x}^{(i)}) + \log p(\boldsymbol{z}|\boldsymbol{x}^{(i)}) + \log p(\boldsymbol{x}^{(i)})] \mathrm{d}\boldsymbol{z} \\ & = \mathrm{KL}[q(\boldsymbol{z}|\boldsymbol{x}^{(i)}), p(\boldsymbol{z}|\boldsymbol{x}^{(i)})] + \int_{\boldsymbol{z}} q(\boldsymbol{z}|\boldsymbol{x}^{(i)})[-\log q(\boldsymbol{z}|\boldsymbol{x}^{(i)}) + \log (p(\boldsymbol{z}|\boldsymbol{x}^{(i)})p(\boldsymbol{x}^{(i)}))] \mathrm{d}\boldsymbol{z} \\ & = \mathrm{KL}[q(\boldsymbol{z}|\boldsymbol{x}^{(i)}), p(\boldsymbol{z}|\boldsymbol{x}^{(i)})] + \int_{\boldsymbol{z}} q(\boldsymbol{z}|\boldsymbol{x}^{(i)})[-\log q(\boldsymbol{z}|\boldsymbol{x}^{(i)}) + \log p(\boldsymbol{x}^{(i)}, \boldsymbol{z})] \mathrm{d}\boldsymbol{z} \\ & = \mathrm{KL}[q(\boldsymbol{z}|\boldsymbol{x}^{(i)}), p(\boldsymbol{z}|\boldsymbol{x}^{(i)})] + \mathbb{E}_{\boldsymbol{z} \sim q(\boldsymbol{z}|\boldsymbol{x}^{(i)})}[-\log q(\boldsymbol{z}|\boldsymbol{x}^{(i)}) + \log p(\boldsymbol{x}^{(i)}, \boldsymbol{z})] \\ & = \mathrm{KL}[q(\boldsymbol{z}|\boldsymbol{x}^{(i)}), p(\boldsymbol{z}|\boldsymbol{x}^{(i)})] + L(\boldsymbol{\theta}, \boldsymbol{\phi}; \boldsymbol{x}^{(i)}) \\ & \geq L(\boldsymbol{\theta}, \boldsymbol{\phi}; \boldsymbol{x}^{(i)}). \end{aligned} \]

利用 KL 散度大於等於 0 這一特性,我們得到了對數似然 \(\log p(\boldsymbol{x}^{(i)})\) 的一個下界 \(L(\boldsymbol{\theta}, \boldsymbol{\phi}; \boldsymbol{x}^{(i)})\),於是可以將最大化對數似然改為最大化這個下界。

這個下界可以進一步寫成

\[\begin{aligned} L(\boldsymbol{\theta}, \boldsymbol{\phi}; \boldsymbol{x}^{(i)}) & = \int_{\boldsymbol{z}} q(\boldsymbol{z}|\boldsymbol{x}^{(i)})[-\log q(\boldsymbol{z}|\boldsymbol{x}^{(i)}) + \log p(\boldsymbol{x}^{(i)}, \boldsymbol{z})] \mathrm{d}\boldsymbol{z} \\ & = \int_{\boldsymbol{z}} q(\boldsymbol{z}|\boldsymbol{x}^{(i)})[-\log q(\boldsymbol{z}|\boldsymbol{x}^{(i)}) + \log (p(\boldsymbol{z})p(\boldsymbol{x}^{(i)}|\boldsymbol{z}))] \mathrm{d}\boldsymbol{z} \\ & = \int_{\boldsymbol{z}} q(\boldsymbol{z}|\boldsymbol{x}^{(i)})[-\log q(\boldsymbol{z}|\boldsymbol{x}^{(i)}) + \log p(\boldsymbol{z}) + \log p(\boldsymbol{x}^{(i)}|\boldsymbol{z})] \mathrm{d}\boldsymbol{z} \\ & = -\int_{\boldsymbol{z}} q(\boldsymbol{z}|\boldsymbol{x}^{(i)})[\log q(\boldsymbol{z}|\boldsymbol{x}^{(i)}) - \log p(\boldsymbol{z})] \mathrm{d}\boldsymbol{z} + \int_{\boldsymbol{z}} q(\boldsymbol{z}|\boldsymbol{x}^{(i)})\log p(\boldsymbol{x}^{(i)}|\boldsymbol{z})] \mathrm{d}\boldsymbol{z} \\ & = -\mathrm{KL}[q(\boldsymbol{z}|\boldsymbol{x}^{(i)}), p(\boldsymbol{z})] + \mathbb{E}_{\boldsymbol{z} \sim q(\boldsymbol{z}|\boldsymbol{x}^{(i)})}[\log p(\boldsymbol{x}^{(i)}|\boldsymbol{z})]. \\ \end{aligned} \]

其中的第一項是 KL 散度損失,第二項是重構損失。

KL 散度損失

使用標準正態分佈作為先驗分佈,即 \(p(\boldsymbol{z}) = N(\boldsymbol{z}; \boldsymbol{0}, \boldsymbol{I})\)

使用一個由 MLP 的輸出來引數化的正態分佈作為近似後驗分佈,即 \(q(\boldsymbol{z}|\boldsymbol{x}^{(i)}; \boldsymbol{\phi}) = N(\boldsymbol{z}; \boldsymbol{\mu}(\boldsymbol{x}^{(i)}; \boldsymbol{\phi}), \boldsymbol{\sigma}^2(\boldsymbol{x}^{(i)}; \boldsymbol{\phi})\boldsymbol{I})\)

選擇正態分佈的好處在於 KL 散度的這個積分可以寫出解析解,訓練時直接按照公式計算即可,無需通過取樣的方式來算積分。

由於我們選擇的是各分量獨立的多元正態分佈,因此只需要推導一元正態分佈的情形即可:

\[\begin{aligned} \mathrm{KL}[N(z; \mu, \sigma^2), N(z; 0, 1)] & = \int_z N(z; \mu, \sigma^2)\log\frac{N(z; \mu, \sigma^2)}{N(z; 0, 1)} \mathrm{d}z \\ & = \int_z N(z; \mu, \sigma^2) \log\frac{\frac{1}{\sqrt{2\pi}\sigma}\exp\left(-\frac{(z - \mu)^2}{2\sigma^2}\right)}{\frac{1}{\sqrt{2\pi}}\exp\left(-\frac{z^2}{2}\right)} \mathrm{d}z \\ & = \int_z N(z; \mu, \sigma^2) \log\left(\frac{1}{\sqrt{\sigma^2}}\exp\left(\frac{1}{2}\left(-\frac{(z - \mu^2)^2}{\sigma^2} + z^2\right)\right)\right) \mathrm{d}z \\ & = \frac{1}{2}\int_z N(z; \mu, \sigma^2) \left(-\log\sigma^2 - \frac{(z - \mu)^2}{\sigma^2} + z^2\right)\mathrm{d}z \\ & = \frac{1}{2}\left(-\log\sigma^2\int_z N(z; \mu, \sigma^2) \mathrm{d}z - \frac{1}{\sigma^2}\int_z N(z; \mu, \sigma^2)(z - \mu)^2\mathrm{d}z + \int_z N(z; \mu, \sigma^2)z^2\mathrm{d}z\right) \\ & = \frac{1}{2}\left(-\log\sigma^2 \cdot 1 - \frac{1}{\sigma^2} \cdot \sigma^2 + \mu^2 + \sigma^2\right) \\ & = \frac{1}{2}(-\log\sigma^2 - 1 + \mu^2 + \sigma^2). \end{aligned} \]

解釋一下倒數第三行的三個積分:

  1. \(\int_z N(z; \mu, \sigma^2) \mathrm{d}z\) 是概率密度函數的積分,也就是 1。
  2. \(\int_z N(z; \mu, \sigma^2)(z - \mu)^2\mathrm{d}z\) 是方差的定義,也就是 \(\sigma^2\)
  3. \(\int_z N(z; \mu, \sigma^2)z^2\mathrm{d}z\) 是正態分佈的二階矩,結果為 \(\mu^2 + \sigma^2\)

重構損失

伯努利分佈模型

\(\boldsymbol{x}\) 是二值向量時,可以用伯努利分佈(兩點分佈)來建模 \(p(\boldsymbol{x}|\boldsymbol{z})\),即認為向量 \(\boldsymbol{x}\) 的每個維度都服從對應的相互獨立的伯努利分佈。使用一個 MLP 來計算各維度所對應的伯努利分佈的引數,第 \(i\) 維伯努利分佈的引數為 \(y_i = \boldsymbol{y}(\boldsymbol{z})_i\),於是有

\[p(\boldsymbol{x}|\boldsymbol{z}) = \prod_{i=1}^D y_i^{x_i}(1 - y_i)^{1 - x_i}, \]

\[\log p(\boldsymbol{x}|\boldsymbol{z}) = \sum_{i=1}^D x_i\log y_i + (1 - x_i)\log(1 - y_i). \]

其中 \(D\) 表示向量 \(\boldsymbol{x}\) 的維度。可見此時最大化 \(\log p(\boldsymbol{x}|\boldsymbol{z})\) 等價於最小化交叉熵損失。

正態分佈模型

\(\boldsymbol{x}\) 是實值向量時,可以用正態分佈來建模 \(p(\boldsymbol{x}|\boldsymbol{z})\)。使用一個 MLP 來計算正態分佈的引數,於是有

\[\begin{aligned} p(\boldsymbol{x}|\boldsymbol{z}) & = N(\boldsymbol{x}; \boldsymbol{\mu}, \boldsymbol{\sigma}^2\boldsymbol{I}) \\ & = \prod_{i=1}^D N(x_i; \mu_i, \sigma_i^2) \\ & = \left(\prod_{i=1}^D\frac{1}{\sqrt{2\pi}\sigma_i}\right)\exp\left(\sum_{i=1}^D-\frac{(x_i - \mu_i)^2}{2\sigma_i^2}\right), \end{aligned} \]

\[\log p(\boldsymbol{x}|\boldsymbol{z}) = -\frac{D}{2}\log 2\pi - \frac{1}{2}\sum_{i=1}^D\log\sigma_i^2 - \frac{1}{2}\sum_{i=1}^D\frac{(x_i - \mu_i)^2}{\sigma_i^2}. \]

很多時候我們會假設 \(\sigma_i^2\) 是一個常數,於是 MLP 只需要輸出均值引數 \(\boldsymbol{\mu}\) 即可。此時有

\[\log p(\boldsymbol{x}|\boldsymbol{z}) \sim -\frac{1}{2}\sum_{i=1}^D(x_i - \mu_i)^2 = -\frac{1}{2}\|\boldsymbol{x} - \boldsymbol{\mu}(\boldsymbol{z})\|^2. \]

可見此時最大化 \(\log p(\boldsymbol{x}|\boldsymbol{z})\) 等價於最小化 MSE 損失。

重引數化技巧

需要使用重引數化技巧解決取樣 \(z\) 時不可導的問題。解決的思路是先從無引數分佈中取樣一個 \(\varepsilon\),再通過變換得到 \(z\)

\(N(\mu, \sigma^2)\) 中取樣一個 \(z\),相當於先從 \(N(0, 1)\) 中取樣一個 \(\varepsilon\),然後令 \(z = \mu + \varepsilon\cdot\sigma\)

相關知識

技巧,通過取對數把乘除變成加減:

\[\ln ab = \ln a + \ln b,\ \ln\frac{a}{b} = \ln a - \ln b. \]

隨機變數的函數的期望:

\[\mathbb{E}_{x \sim P(x)} g(x) = \int_x p(x)g(x) \mathrm{d}x, \]

利用此公式可以將積分改寫成期望的形式,這樣就可以用取樣的方式計算積分了(蒙特卡羅積分法)。

條件概率密度的定義:

\[p_{Y|X}(y|x) = \frac{p(x, y)}{p_X(x)}, \]

此處的 \(p\) 並不是概率而是概率密度函數,但是這個公式在形式上跟條件概率公式是一樣的。

參考資料

蘇劍林的 VAE 系列部落格:

15 分鐘瞭解變分推理: