域泛化(domain generalization, DG) [1][2]旨在從多個源域中學習一個能夠泛化到未知目標域的模型。形式化地說,給定\(K\)個訓練的源域資料集\(\mathcal{S}=\left\{\mathcal{S}^k \mid k=1, \cdots, K\right\}\),其中第\(k\)個域的資料被表示為\(\mathcal{S}^k = \left\{\left(x_i^k, y_i^k\right)\right\}_{i=1}^{n^k}\)。這些源域的資料分佈各不相同:\(P_{X Y}^k \neq P_{X Y}^l, 1 \leq k \neq l \leq K\)。域泛化的目標是從這\(K\)個源域的資料中學習一個具有強泛化能力的模型:\(f: \mathcal{X}\rightarrow \mathcal{Y}\),使其在一個未知的測試資料集\(\mathcal{T}\)(即\(\mathcal{T}\)在訓練過程中不可存取且\(P_{X Y}^{\mathcal{T}} \neq P_{X Y}^i \text { for } i \in\{1, \cdots, K\}\))上具有最小的誤差:
這裡\(\mathbb{E}\)和\(\ell(\cdot, \cdot)\)分別為期望和損失函數。域泛化示意圖如下圖所示:
目前為了解決域泛化中的域漂移(domain shift) 問題,已經提出了許多方法,大致以分為下列三類:
資料操作(data manipulation) 這種方法旨在通過資料增強(data augmentation)或資料生成(data generation)方法來豐富資料的多樣性,從而輔助學習更有泛化能力的表徵。其中資料增強方法常利用資料變換、對抗資料增強(adversarial data augmentation)[3]等手段來增強資料;資料生成方法則通過Mixup(也即對資料進行兩兩線性插值)[4]等手段來生成一些輔助樣本。
表徵學習(representation learning) 這種方法旨在通過學習領域不變表徵(domain-invariant representations),或者對領域共用(domain-shared)和領域特異(domain-specific)的特徵進行特徵解耦(feature disentangle),從而增強模型的泛化效能。該類方法我們在往期部落格《尋找領域不變數:從生成模型到因果表徵 》和《跨域推薦:嵌入對映、聯合訓練和解耦表徵》中亦有詳細的論述。其中領域不變表徵的學習手段包括了對抗學習[5]、顯式表徵對齊(如優化分佈間的MMD距離)[6]等等,而特徵解耦則常常通過優化含有互資訊(資訊瓶頸的思想)或KL散度[7]的損失項來達成,其中大多數會利用VAE等生成模型。
學習策略(learning stategy) 這種方法包括了整合學習[8]、元學習[9]等學習正規化。其中,以元學習為基礎的方法則利用元學習自發地從構造的任務中學習元知識,這裡的構造具體而言是指將源域資料集\(\mathcal{S}\)按照域為單位來拆分成元訓練(meta-train)部分\(\bar{\mathcal{S}}\)和元測試(meta-test)部分\(\breve{\mathcal{S}}\)以便對分佈漂移進行模擬,最終能夠在目標域\(\mathcal{T}\)的final-test中取得良好的泛化表現。
然而,目前大多數域泛化方法需要將不同領域的資料進行集中收集。然而在現實場景下,由於隱私性的考慮,資料常常是分散式收集的。因此我們需要考慮聯邦域泛化(federated domain generalization, FedDG) 方法。形式化的說,設\(\mathcal{S}=\left\{\mathcal{S}^1, \mathcal{S}^2, \ldots, \mathcal{S}^K\right\}\)表示在聯邦場景下的\(K\)個分散式的源域資料集,每個源域資料集包含資料和標籤對\(\mathcal{S}^k=\left\{\left(x_i^k, y_i^k\right)\right\}_{i=1}^{n^k}\),取樣自域分佈\(P_{X Y}^k\)。聯邦域泛化的目標是利用\(K\)個分散式的源域學習模型\(f_\theta: \mathcal{X} \rightarrow \mathcal{Y}\),該模型能夠泛化到未知的測試域\(\mathcal{T}\)。聯邦域泛化的架構如下圖所示:
這裡需要注意的是,傳統的域泛化方法常常要求直接對齊表徵或運算元據,這在聯邦場景下是違反資料隱私性的。此外對於跨域的聯邦學習,由於使用者端異構的資料分佈/領域漂移(如不同的影象風格)所導致的模型偏差(bias),直接聚合本地模型的引數也會導致次優(sub-optimal)的全域性模型,從而更難泛化到新的目標域。因此,許多傳統域泛化方法在聯邦場景下都不太可行,需要因地制宜進行修改,下面試舉幾例:
對於資料操作的方法,我們常常需要用其它領域的資料來對某個領域的資料進行增強(或進行新資料的插值生成),而這顯然違反了資料隱私。目前論文的解決方案是不直接傳資料,而傳資料的統計量來對資料進行增強[10],這裡的統計量指圖片的style(即圖片逐通道計算的均值和方差)等等。
對於表徵學習的方法,也需要在對不同域的表徵進行共用/對比的條件下獲得領域不變表徵(或對錶徵進行分解),而傳送表徵事實上也違反了資料隱私。目前論文采用的解決方案包括不顯式對齊表徵,而是使得所有領域的表徵顯式/隱式地對齊一個參考分佈(reference distribution)[11][12],這個參考分佈可以是高斯,也可以由GAN來自適應地生成。
基於學習策略的方法,如元學習也需要利用多個域的資料來構建meta-train和meta-test,並進行元更新(meta-update),而這也違反了資料隱私性。目前論文的解決方案是使用來自其它域的變換後資料來為當前域構造元學習資料集[13],這裡的變換後資料指影象的幅度譜等等。
本篇論文是聯邦域泛化的第一篇工作。這篇論文屬於基於學習策略(採用元學習)的域泛化方法,並通過傳影象的幅度譜(amplitude spectrum),而非影象資料本身來構建原生的元學習任務,從而保證聯邦場景下的資料隱私性。本文方法的框架示意圖如下:
這裡\(K\)為領域/使用者端的個數。該方法使影象的低階特徵——幅度譜在不同使用者端間共用,而使高階語意特徵——相位譜留在本地。這裡再不同使用者端間共用的幅度譜就可以作為多領域/多源資料分佈供本地元學習訓練使用。
接下來我們看原生的元學習部分。元學習的基本思想是通過模擬訓練/測試資料集的領域漂移來學得具有泛化性的模型引數。而在本文中,本地使用者端的領域漂移來自不同分佈的頻率空間。具體而言,對每輪迭代,我們考慮原生的原輸入圖片\(x_{i}^k\)做為meta-train,它的訓練搭檔\(\mathcal{T}_i^{k}\)則由來自其它使用者端的頻域產生,做為meta-test來表示分佈漂移。
設使用者端\(k\)中的圖片\(x^k_i\)由正向傅立葉變換\(\mathcal{F}\)得到的幅度譜為\(\mathcal{A}_i^k \in \mathbb{R}^{H \times W \times C}\),相位譜為\(\mathcal{P}_i^k \in \mathbb{R}^{H \times W \times C}\)(\(C\)為圖片通道數)。本文欲在使用者端之間交換低階分佈也即幅度譜資訊,因此需要先構建一個供所有使用者端共用的distribution bank \(\mathcal{A} = [\mathcal{A}^1, \cdots, \mathcal{A}^K]\),這裡\(A^k = {\{\mathcal{A}^k_i\}}^{n^k}_{i=1}\)包含了來自第\(k\)個使用者端所有圖片的幅度譜資訊,可視為代表了\(\mathcal{X}^k\)的分佈。
之後,作者通過在頻域進行連續插值的手段,將distribution bank中的多源分佈資訊送到本地使用者端。如上圖所示,對於第\(k\)個使用者端的圖片幅度譜\(\mathcal{A}_i^{k}\),我們會將其與另外\(K-1\)個使用者端的幅度譜進行插值,其中與第\(l(l\neq k)\)個外部使用者端的圖片幅度譜\(\mathcal{A}_j\)插值的結果表示為:
這裡\(\mathcal{M}\)是一個控制幅度譜內低頻成分比例的二值掩碼,\(\lambda\)是插值率。然後以此通過反向傅立葉變換生成變換後的圖片:
就這樣,對於第\(k\)個使用者端的輸入圖片\(x^k_i\),我們就得到了屬於不同分佈的\(K-1\)個變換後的圖片資料\(\mathcal{T}^k_i = \{x^{k\rightarrow l}_i\}_{l\neq k}\),這些圖片和\(x^k_i\)共用了相同的語意標籤。
接下來在元學習的每輪迭代中,我們將原始資料\(x^k_i\)做為meta-train,並將其對應的\(K-1\)個由頻域產生的新資料\(\mathcal{T}^k_i\)做為meta-test來表示分佈漂移,從而完成在當前使用者端的inner-loop的引數更新。
具體而言,元學習正規化可以被分解為兩步:
第一步 模型引數\(\theta^k\)在meta-train上通過segmentaion Dice loss \(\mathcal{L}_{seg}\)來更新:
這裡引數\(\beta\)表示內層更新的學習率。
第二步 在meta-test資料集\(\mathcal{T}^k_i\)上使用元目標函數(meta objective)\(\mathcal{L}_{meta}\)對已更新的引數\(\hat{\theta}^k\)進行進一步元更新。
這裡特別重要的是,第二步所要優化的目標函數由在第一部中所更新的引數\(\hat{\theta}^k\)計算,最終的優化結果覆蓋掉原來的引數\(\theta^k\)。
如果我們將一二步合在一起看,則可以視為通過下面目標函數來一起優化關於引數\(\theta^k\)的內層目標函數和元目標函數:
最後,一旦本地訓練完成,則來自所有使用者端的本地引數\(\theta^k\)會被伺服器聚合並更新全域性模型。
本篇論文屬於基於學習領域不變表徵的域泛化方法,並通過使所有使用者端的表徵對齊一個由GAN自適應生成的參考分佈,而非使使用者端之間的表徵互相對齊,來保證聯邦場景下的資料隱私性。本文方法整體的架構如下圖所示:
注意,這裡所有使用者端共用一個參考分佈,而這通過共用同一個分佈生成器(distribution generator)來實現。在訓練過程一邊使每個域(使用者端)的資料分佈會和參考分佈對齊,一邊最小化分佈生成器的損失函數,使其產生的參考分佈接近所有源資料分佈的「中心」(這也就是」自適應「的體現)。一旦判別器很難區分從特徵提取器中提取的特徵和從分佈生成器中所生成的特徵,此時所提取的特徵就被認為是跨多個源域不變的。這裡的特徵分佈生成器的輸入為噪聲樣本和標籤的one-hot向量,它會按照一定的分佈(即參考分佈)生成特徵。最後,作者還採用了隨機投影層來使得判別器更難區分實際提取的特徵和生成器生成的特徵,使得對抗網路更穩定。在訓練完成之後,參考分佈和所有源域的資料分佈會對齊,此時學得的特徵表徵被認為是通用(universal)的,能夠泛化到未知的領域。
接下來我們來看GAN部分具體的細節。設\(F(\cdot)\)為特徵提取器,\(G(\cdot)\)為分佈生成器,\(D(\cdot)\)為判別器。設由特徵提取器所提取的特徵\(\mathbf{h} = F(\mathbf{x})\)(資料\(\mathbf{x}\)的生成分佈為\(p(\mathbf{h})\)),而由分佈生成器所產生的特徵為\(\mathbf{h}'= G(\mathbf{z})\)(噪聲\(\mathbf{z}\)的生成分佈為\(p(\mathbf{h}')\)。我們設特徵提取器所提取的特徵為負例,生成器所生成的特徵為正例。
於是,我們可以將判別器的優化目標定義為使將特徵提取器所生成的特徵\(\mathbf{h}\)判為正類的概率\(D(\mathbf{h}|\mathbf{y})\)更小,而使將生成器所生成的特徵\(\mathbf{h}'\)判為正類的概率\(D(\mathbf{h}'|\mathbf{y})\)更大。
生成器儘量使判別器\(D(\cdot)\)將其生成特徵\(\mathbf{h}'\)判別為正類的概率\(D\left(\mathbf{h}^{\prime} \mid \mathbf{y}\right)\)更大,以求以假亂真:
特徵提取器也需要儘量使得其所生成的特徵\(\mathbf{h}\)能夠以假亂真:
再加上影象分類本身的交叉熵損失\(\mathcal{L}_{err}\),則總的損失定義為:
論文的最後,作者還對一個問題進行了探討:關於這裡的參考分佈,我們為什麼不用一個預先選好的確定的分佈,要用一個自適應生成的分佈呢?那是因為自適應生成的分佈有一個重要的好處,那就是少對齊期間的失真(distortion)。作者對多個域/使用者端的分佈和參考分佈進行了視覺化,如下圖所示:
(a)中為參考分佈選擇為固定的分佈後,與各域特徵對比的示意圖,圖(b)為參考分佈選擇為自適應生成的分佈後,和各域特徵對比的示意圖。在這兩幅圖中,紅色五角星表示參考分佈的特徵,除了五角星之外的每種形狀代表一個域,每種顏色代表一個類別的樣本。可以看到自適應生成的分佈和多個源域資料分佈的距離,相比固定參考分佈和多個源域資料分佈的距離更小,因此自適應生成的分佈能夠減少對齊期間提取特徵表徵的失真。而更好的失真也就意味著源域資料的關鍵資訊被最大程度的保留,這讓本文的方法所得到的表徵擁有更好的泛化表現。
本篇論文屬於基於學習領域不變表徵的域泛化方法,並通過使所有使用者端的表徵對齊一個高斯參考分佈,而非使使用者端之間的表徵互相對齊,來保證聯邦場景下的資料隱私性。本文的動機源於經典機器學習演演算法的思想,旨在學習一個「簡單」(simple)的表徵從而獲得更好的泛化效能。
首先,作者以生成模型的視角,將表徵\(z\)建模為從\(p(z|x)\)中的取樣,然後在此基礎上定義領域\(k\)的分類目標函數以學得表徵:
這裡領域\(k\)的樣本表徵\(z_j^{(i)}\)通過編碼器+重引數化從\(p(z|x_k^{(i)})\)中取樣產生。
接下來我們來看怎麼使得表徵更「簡單」。本文采用了兩個正則項,一個是關於表徵的\(L2\)正則項來限制表徵中所包含的資訊;一個是在給定\(y\)的條件下,\(x\)與\(z\)的條件互資訊\(I(x, z\mid y)\)(的上界)來使表徵只學習重要的資訊,而忽視諸如圖片背景之類的偽相關性(spurious correlations)。
關於表徵\(z\)的\(L2\)正則項定義如下:
於是,上式的微妙之處在於可以和領域不變表徵聯絡起來,事實上我們有\(\mathcal{L}_k^{L 2 R}=\mathbb{E}_{p_k(x)}\left[\mathbb{E}_{p(z \mid x)}\left[\|z\|_2^2\right]\right]=\mathbb{E}_{p_k(x, z)}\left[\|z\|_2^2\right]=\mathbb{E}_{p_k(z)}\left[\|z\|_2^2\right]=2 \sigma^2 \mathbb{E}_{p_k(z)}[-\log q(z)]=2 \sigma^2 H\left(p_k(z), q(z)\right)\),這裡\(H\left(p_k(z), q(z)\right)=H\left(p_k(z)\right)+ D_{\text{KL}} \left[p_k(z) \Vert q(z)\right]\),參考分佈\(q(z)=\mathcal{N}\left(0, \sigma^2 I\right)\)。如果\(H(p_i(z))\)在訓練中並未發生大的改變,那麼最小化\(l_k^{L2R}\)也就是在最小化\(D_{\text{KL}}[p_k(z) \Vert q(z)]\),也即在隱式地對齊一個參考的邊緣分佈\(q(z)\),而這就使得標準的邊緣分佈\(p_k(z)\)是跨域不變的。注意該對齊是不需要顯式地比較不同使用者端分佈的。
接下來我們來看條件互資訊項。在資訊瓶頸理論中,常對\(x\)和表徵\(z\)之間的互資訊項\(I(x, z)\)進行最小化以對\(z\)中所包含的資訊進行加以正則,但是這樣的約束在實踐中如果係數沒調整好,就很可能過於嚴格了,畢竟它迫使表徵不包含資料的資訊。因此,在這篇論文中,作者選擇最小化給定\(y\)時\(x\)和\(z\)之間的條件互資訊。領域\(k\)的條件互資訊被計算為:
直觀地看,\(\bar{f}_k\)和\(I_k(x, z\mid y)\)共同作用,迫使表徵\(z\)僅僅擁有預測標籤\(y\)使所包含的資訊,而沒有關於\(x\)的額外(即和標籤無關的)資訊。
然而,這個互資訊項是難解(intractable)的,這是由於計算\(p_k(z|y)\)很難計算(由於需要對\(x\)進行積分將其邊緣化消掉)。因此,作者匯出了一個上界來對齊進行最小化:
這裡\(r(z|y)\)可以是一個輸入\(y\)輸出分佈\(r(z|y)\)的神經網路,作者將其設定為高斯\(\mathcal{N}\left(z ; \mu_y, \sigma_y^2\right)\),這裡\(u_y\),\(\sigma^2_y\)(\(y=1, 2, \cdots, C\))是需要優化的神經網路引數,\(C\)是類別數量。
事實上,該正則項和域泛化中的條件分佈對齊亦有著理論上的聯絡,這是因為\( \mathcal{L}_k^{C M I}=\mathbb{E}_{p_k(x, y)}[D_{\text{KL}}[p(z \mid x) \Vert r(z \mid y)]] \geq \mathbb{E}_{p_k(y)}\left[D_{\text{KL}}\left[p_k(z \mid y) \Vert r(z \mid y)\right]\right] \)。因此,最小化\(\mathcal{L}_k^{CMI}\)我們必然就能夠最小化\(D_{\text{KL}}\left[p_k(z \mid y) \Vert r(z \mid y)\right]\)(因為\(\mathcal{L}^{CMI}_k\)是其上界),使得\(p_k(z|y)\)和\(r(z|y)\)互相接近,即:\(p_k(z|y)\approx r(z|y)\)。因此,模型會嘗試迫使\(p_k(z \mid y) \approx p_l(z \mid y)(\approx r(z \mid y))\)(對任意使用者端/領域\(k, l\))。這也就是說,我們是在做給定標籤\(y\)時表徵\(z\)的條件分佈的隱式對齊,這在傳統的領域泛化中是一種很常見與有效的技術,區別就是這裡不需要顯式地比較不同使用者端的分佈。
最後,每個使用者端的總體目標函數可以表示為:
總結一下,這裡\(L2\)範數正則項\(\mathcal{L}_k^{L2R}\)和給定標籤時資料和表徵的條件互資訊\(\mathcal{L}_k^{CMI}\)(的上界)用於限制表徵中所包含的資訊。此外,\(\mathcal{L}_k^{L2R}\)將邊緣分佈\(p_k(z)\)對齊到一個聚集在0周圍的高斯分佈,而\(\mathcal{L}_i^{CMI}\)則將條件分佈\(p_k(z|y)\)對齊到一個參考分佈(在實驗環節作者亦將其選擇為高斯)。
本篇論文屬於基於資料操作的域泛化方法,並通過構造一個style bank供所有使用者端共用(類似CVPR21那篇),以使使用者端在不共用資料的條件下基於風格(style)來進行資料增強,從而保證聯邦場景下的資料隱私性。本文方法整體的架構如下圖所示:
如圖所示,每個client的資料集都有自己的風格。且對於每個使用者端而言,都會接受其餘使用者端的風格來進行資料增強。事實上,這樣就可以使得分散式的使用者端在不洩露資料的情況下擁有相似的資料分佈 。在本方法中,所有使用者端的本地模型都擁有一致的學習目標——那就是擬合來自於所有源域的styles,而這種一致性就避免了本地模型之間的模型偏差,從而避免了影響全域性模型的效果。此外,本方法可和其它DG的方法結合使用,從而使得其它中心化的DG方法均能得到精度的提升。
關於本文采用的風格遷移模型,有下列要求:1、所有使用者端共用的style不能夠被用來對資料集進行重構,從而保證資料隱私性;2、用於風格遷移的方法需要是一個實時的任意風格遷移模型,以允許高效和直接的風格遷移。本文最終選擇了AdaIN做為原生的風格遷移模型。整個跨使用者端/領域風格遷移流程如下圖所示:
可以看到,整個跨使用者端/領域風格遷移流程分為了三個階段:
1. Local style Computation
每個使用者端需要計算它們的風格並上傳到全域性伺服器。其中可選擇單張圖片風格(single image style)和整體領域風格(overall domain style )這兩種風格來進行計算。
如果單張圖片風格被用於風格遷移,那麼就需要將該使用者端不同圖片對應的多種風格都上傳到伺服器,從而避免單張圖片的偏差並增加多樣性。而這就需要建立本地圖片的style bank \(\mathcal{S}_k^{single}\)並將其上傳到伺服器。這裡作者隨機選擇\(J\)張影象的style加入了本地style bank:
相比單張圖片風格,整體領域風格的計算代價非常高。不過,由於每個使用者端/領域只有一個領域風格\(S_k^{overall}\),選擇上傳整體領域風格到伺服器的通訊效率會更高。
2. Style Bank on Server
當伺服器接收到來自各個使用者端的風格時,它會將所有風格彙總為一個style bank \(\mathcal{B}\) 並將其廣播回所有使用者端。在兩種不同的風格共用模式下,style bank亦會有所不同。
\(\mathcal{B}_{single}\)比\(\mathcal{B}_{overall}\)會消耗更多儲存空間,因此後者會更加通訊友好。
3. Local Style Transfer
當用戶端\(k\)收到style bank \(\mathcal{B}\)後,本地資料會通過遷移\(\mathcal{B}\)中的風格來進行增強,而這就將其它領域的風格引入了當前使用者端。作者設定了超引數\(L \in\{1,2, \ldots, K\}\)做為增強級別,意為從style bank \(\mathcal{B}\)中隨機選擇\(L\)個域所對應的風格來對每個圖片進行增強,因此\(L\)表明了增強資料集的多樣性。設第\(k\)個使用者端資料集大小為\(N_k\),則在進行跨使用者端的領域遷移之後,增強後資料集的大小會變為\(N_k \times L\)。其中對使用者端\(k\)中的每張圖片\(I^{(i)}_k\),其對應的每個被選中的域都會擁有一個style vector\(S\)被作為影象生成器\(G\)的輸入。這裡關於style vector的獲取有個細節需要注意:假設我們選了域\(k\),如果遷移的是整體領域風格,則\(S^{overall}_k\)直接即可做為style vector;如果遷移的是單圖片風格,則還會進一步從選中\(\mathcal{S}^{single}_k\)中隨機選擇一個風格\(S_k^{(i)}\)做為域\(k\)的style vector。對以上兩種風格模式而言,如果一個域被選中,則其對應的風格化圖片就會被直接加入增強後的資料集中。