Lora升級!ReLoRa!最新論文 High-Rank Training Through Low-Rank Updates

2023-10-25 18:01:01

關注公眾號TechLead,分享AI與雲服務技術的全維度知識。作者擁有10+年網際網路服務架構、AI產品研發經驗、團隊管理經驗,同濟本復旦碩,復旦機器人智慧實驗室成員,阿里雲認證的資深架構師,專案管理專業人士,上億營收AI產品研發負責人。

摘要

儘管通過擴充套件導致具有數千億引數的大型網路在統治和效率方面表現突出,但訓練過引數化模型的必要性仍然難以理解,且替代方法不一定能使訓練高效能模型的成本降低。在本文中,我們探索了低秩訓練技術作為訓練大型神經網路的替代方法。我們引入了一種名為 ReLoRA 的新方法,該方法利用低秩更新來訓練高秩網路。我們將 ReLoRA 應用於預訓練最多達 350M 引數的變換器語言模型,並展示了與常規神經網路訓練相當的效能。此外,我們觀察到 ReLoRA 的效率隨著模型大小的增加而提高,使其成為訓練多十億引數網路的有效方法。我們的研究發現揭示了低秩訓練技術的潛力及其對擴充套件規律的影響。程式碼已在 GitHub 上提供。

1 引言

在過去的十年中,機器學習領域一直被訓練越來越多引數化的網路或採取「疊加更多層」的方法所主導。大型網路的定義已經從具有1億個引數的模型演變到數百億個引數,這使得與訓練這樣的網路相關的計算成本對大多數研究團隊來說變得過於昂貴。儘管如此,與訓練樣本相比,需要訓練數量級更多的引數的模型的必要性在理論上仍然理解不足。

例如更有效的計算擴充套件最佳化、檢索增強模型、以及通過更長時間訓練較小模型的簡單方法等替代擴充套件方法,都提供了新的權衡。然而,它們並沒有讓我們更接近理解為什麼我們需要過引數化的模型,也很少使這些模型的訓練民主化。例如,訓練RETRO需要一套複雜的訓練設定和基礎設施,能夠快速搜尋數萬億的標記,而訓練LLaMA-6B仍然需要數百個GPU。

相比之下,像零冗餘優化器、16位元訓練、8位元推斷和引數有效微調(PEFT)等方法在使大型模型更易存取方面發揮了關鍵作用。具體來說,PEFT方法使得在消費者硬體上微調十億規模的語言或擴散模型成為可能。這引發了一個問題:這些方法是否也能惠及預訓練?

一方面,預訓練正是允許對網路進行微小修改以使其適應新任務的步驟。Aghajanyan等人已經證明,預訓練網路越多,學習任務所需的更改的秩就越小。另一方面,多項研究已經證明了語言和視覺模型提取和利用的特徵的簡單性,以及它們的低固有維度。例如,變換器中的注意力模式通常呈現小秩,這已經被成功用於開發更高效的注意力變體。此外,訓練過程中也並不需要過引數化。彩票票據假說從經驗上證明,在初始化(或訓練早期)時,存在子網路 - 獲勝票據,當單獨訓練時可以達到整個網路的效能。

在本研究中,我們專注於低秩訓練技術,並介紹了ReLoRA,它使用低秩更新來訓練高秩網路。我們憑經驗證明ReLoRA執行高秩更新,並實現與常規神經網路訓練相似的效能。ReLoRA的組成部分包括神經網路的初始完全秩訓練(類似於Frankle等人),LoRA訓練,重新開始,鋸齒狀學習速率計劃,以及部分優化器重置。我們對ReLoRA在高達350M引數的變換器語言模型上的效果進行評估。我們選擇專注於自迴歸語言建模,因為這種方法在神經網路的大多數應用中已經展示了其通用性。最後,我們觀察到ReLoRA的效率隨著模型大小的增加而增加,使其成為有效訓練多十億引數網路的可行選擇。

本研究中的每個實驗均未使用超過8個GPU天的計算。

2 相關工作

縮放與效率 過引數化與神經網路的可訓練性和泛化之間的關係已經得到了廣泛的研究,但仍然是一個謎。此外,縮放法則展示了網路大小與其在各種模態之間的效能之間存在簡單而強烈的冪律依賴關係。這一發現不僅支援過引數化,而且還鼓勵對非常消耗資源的神經網路進行訓練。然而,彩票假設表明原則上可以最小化過引數化。具體來說,它表明在訓練初期存在可以訓練以達到整個網路效能的子網路(中獎彩票)。

引數高效微調 Aghajanyan等人發現預訓練減少了網路的變化量或其固有維數,以通過微調學習新任務。即,更大的網路或在更多資料上預訓練的網路在學習新任務時需要較小的修改,就其範圍的秩而言。這解釋了引數高效微調方法的成功,並且還激發了像LoRA和Compacter這樣的低秩微調方法的發展。

低秩神經網路訓練 在CNN壓縮、正則化和高效訓練的背景下已經探討了訓練低秩表示。然而,這些方法中的大多數要麼特定於CNN,要麼不具備良好的可延伸性,要麼沒有在具有數億引數的大型轉換器上進行評估,而這些轉換器可以從高效訓練中大大受益。雖然轉換器已被證明具有低秩的內部維數和表示,但Bhojanapalli等人的研究表明,在多頭注意力中關鍵和查詢投影的低秩限制了轉換器的效能。我們自己的實驗(第3節)也表明,與完整秩基線和ReLoRA相比,低秩轉換器的效能明顯較差。

3 方法

讓我們從重新審視線性代數101開始。特別是,我們對兩個矩陣之和的秩感興趣:
rank(A + B) ≤ rank(A) + rank(B)。(1)
對和的秩的這個界限是緊的:對於矩陣A,有rank(A) < dim(A),存在B,使得rank(B) < dim(B),並且矩陣之和的秩高於A或B。我們想要利用這個屬性來製造一種靈活的引數高效的訓練方法。我們從LoRA開始,它是一種基於低秩更新思想的引數高效微調方法。LoRA可以應用於任何通過W ∈ R^m×n引數化的線性操作。具體來說,LoRA將權重更新δW分解為低秩乘積WAWB,如方程2所示,其中s ∈ R是通常等於1/r的固定縮放因子。
δW = sWAWB
WA ∈ R^in×r
, WB ∈ R^r×out(2)
在實踐中,LoRA通常是通過新增新的可訓練引數WA和WB來實現的,這些引數可以在訓練後合併回原始引數。因此,即使方程1允許在訓練時間P_t δWt內的總更新具有高於任何單個矩陣的更高的秩,LoRA實現也受到秩r = maxWA,WB rank(WAWB)的限制。

如果我們可以重新啟動LoRA,即在訓練期間合併WA和WB並重置這些矩陣的值,我們可以增加更新的總秩。多次這樣做將整個神經網路更新帶到
∆W = ΣT1_t=0 δWt + ΣT2_t=T1 δWt + · · · + ΣTN_t=TN−1 δWt = sW1_AW1_B + sW2_AW2_B + · · · + sWN_AWN_B(3)
其中,總和是獨立的,意味著rank(Wi_AWi_B) + rank(Wj_AWj_B) ≥ r。然而,在實踐中實現重新啟動並不是微不足道的,需要對優化過程進行一些修改。天真的實現會導致模型在重新啟動後立即發散。與僅依賴於當前優化時間步的梯度值的普通隨機梯度下降不同,Adam更新主要由之前步驟累積的梯度的第一和第二時刻指導。在實踐中,梯度矩滑引數β1和β2通常非常高,即0.9 - 0.999。假設在重新初始化邊界W1_A和相應的梯度矩mA和vA處是全秩的(r)。那麼,在合併和重新初始化後,繼續使用W2_A的舊梯度矩將引導它沿著W1_A的相同方向,並優化相同的子空間。

為了解決這個問題,我們提出了ReLoRA。ReLoRA在合併和重新初始化期間對優化器狀態進行部分重置,並將學習率設定為0,並隨後進行熱啟動。具體來說,我們將99%的低幅度優化器狀態值設定為零,並使用鋸齒狀餘弦學習率計劃(圖2)。我們的消融研究(表3)表明,這兩項修改都是提高LoRA效能的必要條件。

重申一下,ReLoRA是一種受LoRA啟發的低秩訓練方法,通過重新啟動來增加更新的有效秩,使用部分優化器重置和鋸齒排程器來穩定訓練和熱啟動。所有這些都使ReLoRA能夠通過一次僅訓練一小部分引數實現與全秩訓練相當的效能,特別是在大型變換器網路中。ReLoRA在演演算法1中描述。

提高計算效率 與其他低秩訓練技術不同,ReLoRA通過保持原始網路的凍結權重並新增新的可訓練引數來遵循LoRA方法。乍一看,這似乎在計算上是低效的;然而,凍結和可訓練引數之間的區別在引數高效微調中起到了關鍵作用。

這些方法通過減小梯度和優化器狀態的大小,顯著提高了訓練時間和記憶體效率。值得注意的是,Adam狀態消耗的記憶體是模型權重的兩倍。此外,對於大型網路,通常的做法是以32位元精度保持梯度累積緩衝區,從而增加了梯度的記憶體消耗的重要開銷。

通過大幅減少可訓練引數的數量,ReLoRA使得能夠使用更大的批次大小,最大化硬體效率。此外,它還減少了分散式設定中的頻寬要求,這通常是大規模訓練的限制因素。

此外,由於凍結引數在重新啟動之間沒有被更新,所以它們可以保持在低精度量化格式中,進一步減少它們的記憶體和計算影響。

這一額外的優化有助於整體提高ReLoRA在記憶體利用和計算資源方面的效率,並在規模上增加。

4 實驗

為了評估ReLoRA的有效性,我們將其應用於使用各種模型大小:60M、130M、250M和350M,在C4資料集上訓練變換器語言模型。語言建模已被證明是機器學習的基本任務,它能夠實現文字和影象分類、翻譯、程式設計、上下文學習、逐步推理等許多其他新興能力。鑑於其重要性,本文的目的僅關注語言建模。

架構和訓練超引數 我們的架構基於變換器,並與LLaMA非常相似。具體來說,我們使用預歸一化、RMSNorm、SwiGLU啟用、全連線隱藏狀態大小,以及旋轉嵌入。對於所有LoRA和ReLoRA實驗,我們使用秩r = 128,因為我們的初步實驗顯示它具有最佳的困惑度/記憶體權衡。所有超引數均在表1中呈現。

我們對所有浮點操作使用bfloat16,並使用Flash注意力進行有效的注意力計算。與LLaMA中使用float32進行softmax計算的注意力相比,這增加了50-100%的訓練吞吐量,而沒有任何訓練穩定性問題。

我們大部分模型在8個RTX 4090上訓練了一天或更短的時間。由於計算限制,我們訓練的模型要比LLaMA小得多,最大的模型擁有350M個引數,與BERT Large相同。我們根據Chinchilla縮放定律為所有模型選擇預訓練令牌的數量,除了最大的一個,我們為其訓練了6.8B個令牌,而9.5B個令牌是Chinchilla最優的。

ReLoRA和基線設定 在我們的低秩訓練實驗中,ReLoRA替換了所有注意力和全連線網路引數,同時保持嵌入全秩。RMSNorm引數化保持不變。由於ReLoRA封裝的模型比全秩訓練具有更少的可訓練引數,因此我們包括了一個控制基線,即具有與ReLoRA相同數量可訓練引數的全秩變換器。

我們從全秩訓練的5,000次更新步驟的檢查點開始初始化ReLoRA,並在此後的每5,000步重置一次,總共3次。每次重置後,基於大小修剪99%的優化器狀態,並在接下來的100次迭代中預熱損失。ReLoRA引數按照LoRA的最佳實踐重新初始化,A矩陣使用Kaiming初始化,B矩陣使用零。如果不使用重新啟動,B矩陣也使用Kaiming初始化以避免梯度對稱性問題。



5 結果

引數高效的預訓練 我們的主要結果在表2中展示。ReLoRA顯著優於低秩LoRA訓練,展示了我們所提出修改的有效性(在第3節中剖析)。此外,ReLoRA的表現與全秩訓練相似,且隨著網路大小的增加,效能差距逐漸減小。

通過低秩更新進行高秩訓練 為了確定ReLoRA是否執行比LoRA更高的秩更新,我們繪製了ReLoRA、LoRA和全秩訓練的熱啟動權重與最終權重之間差異的奇異值譜圖。圖3描繪了LoRA和ReLoRA在WQ、WK、WV和Wdown的奇異值之間的顯著定性差異。

雖然LoRA的大部分奇異值為零(圖4),且有顯著數量的異常高值超過1.5,但ReLoRA在0.1和1.0之間呈現更高的分佈質量,讓人聯想到全秩訓練。這一觀察強調了高秩更新的重要性,並展示了ReLoRA的定性功效,其通過執行多個低秩更新實現高秩更新。

5.1 剖析研究

我們對ReLoRA的四個關鍵元件:重啟、鋸齒狀排程、優化器重置和溫暖啟動進行剖析研究,使用130M大小的模型。結果展示在表3中。在本節中,我們將重點關注和分析這些元件的某些組合。

LoRA ReLoRA,沒有上述元件,本質上等同於通過LoRA引數化訓練低秩網路。這種方法產生了極高的困惑度,表明簡單的矩陣分解與全秩訓練有顯著不同的訓練動態。

新增重啟和優化器重置 ReLoRA,沒有鋸齒狀排程和優化器重置,表現與LoRA相似,因為舊的優化器狀態將新初始化的引數強制進入與先前權重相同的子空間,限制了模型的容量。然而,用ReLoRA進行天真的優化器重置會導致模型發散。鋸齒狀排程有助於穩定訓練,並對混合物產生積極影響。在我們的初步實驗中,我們還觀察到,部分優化器重置和鋸齒狀排程器的組合允許更快的預熱,低至50步,而不是從頭開始初始化優化器時所需的數百步。

溫暖啟動 溫暖啟動顯示了最顯著的改進,使困惑度降低了近10點。為了調查預熱後訓練是否有助於損失,我們測量了預熱網路的困惑度,等於27.03。它優於所有低秩方法,除了我們最終的ReLoRA配方,但仍然顯示出與最終網路的顯著差異。這展示了早期訓練的重要性,類似於彩票假說與倒帶的概念。

6 結論

在本文中,我們研究了大型變換器語言模型的低秩訓練技術。我們首先檢查了簡單低秩矩陣分解(LoRA)方法的侷限性,並觀察到它在有效訓練高效能變換器模型方面存在困難。為解決這個問題,我們提出了一種名為ReLoRA的新方法,它利用秩的和性質通過多個低秩更新來訓練高秩網路。與彩票假說和倒帶相似,ReLoRA在轉變為ReLoRA之前採用全秩訓練的溫暖啟動。此外,ReLoRA引入了合併和重新初始化(重啟)策略、鋸齒狀學習速率排程器和部分優化器重置,這些共同增強了ReLoRA的效率,並使其更接近全秩訓練,特別是在大型網路中。隨著網路大小的增加,ReLoRA的效率提高,使其成為多十億規模訓練的可行候選方案。

我們堅信,低秩訓練方法的發展對於提高訓練大型語言模型和一般神經網路的效率具有很大的潛力。此外,低秩訓練還有潛力為深度學習理論的進步提供有價值的見解,有助於我們通過梯度下降理解神經網路的可訓練性以及在過引數化體系中的卓越泛化能力。

7 侷限性和未來工作

超越350M的擴充套件 由於計算資源有限,我們的實驗僅限於訓練多達350M引數的語言模型。然而,ReLoRA已經在此規模上展示了有希望的結果。不過,我們預計其真正的潛力將在1B+引數區域實現。此外,雖然350M模型勝過控制基線,但並未繼續縮小ReLoRA和全秩訓練之間的差距的趨勢。我們將這一現象歸因於次優的超引數選擇,這需要進一步研究。

此外,在60-350M的實驗中,儘管ReLoRA顯著減少了可訓練引數的數量,但我們並未觀察到對這種大小的網路在記憶體和計算方面的實質改進。為了評估我們當前實現在更大規模上的效率,我們訓練了1.3B引數的模型進行少量迭代,以估計ReLoRA的記憶體和計算改進。在這個規模下,我們觀察到30%的記憶體消耗減少和52%的訓練吞吐量增加。我們期望在更大的網路中觀察到相對全訓練基線的更大改進,因為ReLoRA的可訓練引數數量(與LoRA類似)相較於凍結引數的數量增加得要慢得多。ReLoRA的實現可以通過有效利用ReLoRA層的梯度檢查點、自定義反向函數和將凍結模型權重轉換為int8或int4量化格式[14]來進一步改進。

與其他低秩訓練方法的比較 早期的工作已經探索了許多低秩訓練方法與其他模型架構的組合[44,49,55]。我們的工作與這些早期努力有兩個方面的不同。首先,我們提出的方法通過低秩訓練執行高秩更新。其次,我們的工作展示了在具有100M+引數的大規模變換器語言模型中,低秩訓練方法的競爭力。

關注公眾號TechLead,分享AI與雲服務技術的全維度知識。作者擁有10+年網際網路服務架構、AI產品研發經驗、團隊管理經驗,同濟本復旦碩,復旦機器人智慧實驗室成員,阿里雲認證的資深架構師,專案管理專業人士,上億營收AI產品研發負責人。
如有幫助,請多關注
TeahLead KrisChang,10+年的網際網路和人工智慧從業經驗,10年+技術和業務團隊管理經驗,同濟軟體工程本科,復旦工程管理碩士,阿里雲認證雲服務資深架構師,上億營收AI產品業務負責人。