使用 PyTorch FSDP 微調 Llama 2 70B

2023-12-12 06:00:20

引言

通過本文,你將瞭解如何使用 PyTorch FSDP 及相關最佳實踐微調 Llama 2 70B。在此過程中,我們主要會用到 Hugging Face Transformers、Accelerate 和 TRL 庫。我們還將展示如何在 SLURM 中使用 Accelerate。

完全分片資料並行 (Fully Sharded Data Parallelism,FSDP) 是一種訓練正規化,在該正規化中優化器狀態、梯度和模型引數都會被跨裝置分片。前向傳播時,每個 FSDP 單元執行 all gather 以獲取完整的權重,然後用它們進行計算並在計算後丟棄掉其他裝置的分片。隨後是反向傳播,然後就是損失計算。反向傳播時,每個 FSDP 單元執行 all gather 操作以獲取完整的權重,並執行計算以獲得本地 batch 的梯度。這些梯度通過 reduce scatter 在裝置上進行均值計算並分片,這樣每個裝置都可以更新其對應分片的引數。有關 PyTorch FSDP 的更多資訊,請參閱此博文: 使用 PyTorch 完全分片資料並行技術加速大模型訓練

(圖源: 連結)

使用的硬體

節點數: 2,至少 1 個節點
每節點 GPU 數: 8
GPU 型別: A100
GPU 視訊記憶體: 80GB
節點內互聯: NVLink
每節點記憶體: 1TB
每節點 CPU 核數: 96
節點間互聯: AWS 的 Elastic Fabric Adapter (EFA)

微調 LLaMa 2 70B 面臨的挑戰

在嘗試使用 FSDP 微調 LLaMa 2 70B 時,我們主要遇到了三個挑戰:

  1. FSDP 會先載入整個預訓練模型,然後再對模型進行分片。這樣就意味著節點內的每個程序 (即 rank) 都會載入整個 Llama-70B 模型,因此需要 7048 GB ~ 2TB 的 CPU 記憶體,這個算式中 4 是每個引數所需位元組數,8 是每個節點的 GPU 數。這會導致 CPU 記憶體不足,進而導致程序終止。
  2. 使用 FULL_STATE_DICT 來儲存完整中間檢查點並將其解除安裝至 rank 0 的 CPU 記憶體中需要花費大量時間,且由於在此期間通訊庫需要無限期掛起等待儲存完成,因此經常會導致 NCCL 超時錯誤。然而,完全關掉這個選項也不好,因為在訓練結束時我們需要儲存完整的模型狀態字典,而不是 FSDP 式分片的狀態字典。
  3. 我們需要提高速度並減少視訊記憶體使用,以加快訓練並節約計算成本。

下文,我們主要討論如何一一解決上述挑戰,最終微調出一個 70B 的模型!

先列出重現結果所需的所有資源:

  1. 程式碼庫: https://github.com/pacman100/DHS-LLM-Workshop/tree/main/chat_assistant/training,程式碼中包含了使能 flash 注意力 V2 的熱修補程式
  2. FSDP 組態檔: https://github.com/pacman100/DHS-LLM-Workshop/blob/main/chat_assistant/training/configs/fsdp_config.yaml
  3. SLURM 啟動指令碼 - launch.slurm : https://gist.github.com/pacman100/1cb1f17b2f1b3139a63b764263e70b25
  4. 模型: meta-llama/Llama-2-70b-chat-hf
  5. 資料集: smangrul/code-chat-assistant-v1 (混合了 LIMA 和 GUANACO 資料集,且已轉換為訓練所需的格式)

準備工作

首先按照 此步驟 安裝 Flash Attention V2。然後,安裝最新的 PyTorch nightly (CUDA ≥11.8)。接著,根據 此檔案 安裝其餘依賴軟體。在本文中,我們是從主分支安裝