在 TPU 上調度模型資源

這份文件提供語言模型擴充的總覽,包括 TPU 的運作方式和彼此之間的溝通方式、大型語言模型在實際硬體上的執行方式,以及如何在訓練和推論期間平行處理模型,讓模型在大規模運作時也能保持高效率。我們提供相關資訊,協助您評估訓練 LLM 的成本、服務模型所需的記憶體容量,以及如何在多個 TPU 中有效分割模型。

雖然許多深度學習技術都很複雜,但即使大規模使用,最佳化模型效能也不必如此。基本原則適用於任何地方,無論是處理單一加速器還是數以萬計的加速器,瞭解這些原則都能讓您執行許多實用操作:

  • 估算模型各部分與理論最佳值的接近程度。
  • 在不同規模下,針對不同平行處理配置做出明智選擇 (如何在多部裝置之間分割運算)。
  • 估算訓練及執行大型 Transformer 模型所需的費用和時間。
  • 設計可充分運用 TPU 架構的演算法。
  • 根據對演算法效能限制的明確瞭解,設計模型架構。

必要條件

您應基本瞭解 LLM 和 Transformer 架構,但不一定要瞭解這些模型的大規模運作方式。您應瞭解 LLM 訓練的基本概念,最好也對 JAX 有基本認識。如要瞭解 Transformer 架構的實用背景資訊,請參閱下列文章:

熟悉這些必要條件後,您應該就能輕鬆估算特定 TPU 平台上的 Transformer 模型最佳平行處理配置。您也可以估算訓練和推論所需的時間。

模型擴展的重要性

如今,大型語言模型和大多數小型模型都接近硬體限制,因此開發模型時,您必須考慮大規模效率。如果基準的勝率提高 20%,但屋頂線效率卻降低 20%,那麼這項進展就沒有意義。有前景的模型架構經常會失敗,不是因為無法有效率地大規模執行,就是因為缺乏最佳化工作,無法達成這個目標。

模型擴充的目標是增加用於訓練或推論的晶片數量,同時實現吞吐量成比例的線性增長。這就是所謂的強效調度。雖然新增晶片 (平行處理) 通常會縮短運算時間,但也會增加晶片間的通訊量。如果通訊時間長於運算時間,模型就會受到通訊限制,無法順利擴充。充分瞭解硬體,預測這些瓶頸會出現在何處,有助於設計或重新設定模型,避免這些瓶頸。

以下各節將概略說明如何擴充 TPU 硬體,以及 Transformer 架構的演進歷程。這項資訊對設計新架構的研究人員,以及致力於讓當代 LLM 快速執行的工程師來說,都非常實用。

第 1 部分:概念

本節說明屋頂線分析,以及限制模型擴展能力 (通訊、運算和記憶體) 的因素。接著,我們將說明 TPU 的運作方式,包括個別晶片,以及 (至關重要) 頻寬和延遲有限的晶片間連結互連系統。

第 2 部分:擴大 Transformer 的規模

請務必瞭解 Transformer 架構的每個部分:每個矩陣的確切大小、發生正規化的位置,以及每個部分有多少參數和 FLOP。本節將仔細介紹 Transformer 數學,說明如何計算訓練和推論的參數和 FLOP。這項資訊會顯示模型使用的記憶體量、運算或通訊所花費的時間,以及相對於前饋區塊,注意力何時會變得重要。

最後,這個部分有助於解答基本問題:假設模型大小固定,且提供一定數量的晶片,如何平行化模型以維持強大的資源調度條件。為回答這個問題,本節將探討四種主要平行處理技術,這些技術可用於將模型分割到多個晶片上:資料、張量、管道和專家。此外,本文也說明其他可減少記憶體需求的技術,例如重新具體化、以 ZeRO 為基礎的模型分片、主機卸載和梯度累積。

  • Transformer 數學運算簡介:本節將逐步說明數學運算,回答有關 Transformer 在正向和反向傳遞期間使用的 FLOP 數量、計算參數數量,以及 KV 快取大小的問題。
  • 訓練用的 Transformer 平行化:本節將詳細說明如何協調 FSDP、Megatron 分片和管道平行化,盡可能提高訓練效率。說明如何為特定模型大小和批量,在固定數量的晶片間決定最佳分配方式,以達到尖峰總處理量。
    • 在 TPU 上訓練 Llama 3:這個小節說明如何在 TPU 上訓練 Llama 3、可能需要多少時間,以及可能花費多少費用。
  • 推論的 Transformer 擴縮:模型訓練完成後,需要提供服務。推論會新增延遲這個考量,並改變記憶體環境。本節說明分散式服務的運作方式,以及如何考量 KV 快取。
    • 在 TPU 上提供 Llama 3:本小節說明如何在 TPU 上提供 Llama 3、可能產生的費用,以及延遲和輸送量的取捨。

第 3 部分:實際導入

本節說明如何使用 JAX 實作縮放概念,以及在發生錯誤時如何剖析及偵錯程式碼。

  • 分析 TPU 程式: 實際的大型語言模型很複雜,難以開發、最佳化及偵錯。 本節說明 JAX + XLA 堆疊,以及如何使用 JAX/TensorBoard 分析器偵錯及修正實際問題。
  • 在 JAX 中為 TPU 編寫程式: 本節說明如何使用 JAX API 平行化運算。