TPU でモデルをスケーリングする
このドキュメントでは、言語モデルをスケーリングする方法の概要について説明します。TPU の仕組みや相互通信の仕組み、実際のハードウェアでの LLM の動作、大規模なスケールで効率的に実行できるようにトレーニングと推論中にモデルを並列処理する方法などについて説明します。LLM のトレーニングにかかる費用、モデルのサービングに必要なメモリ量、複数の TPU 間でモデルを効果的にシャーディングする方法を評価するのに役立つ情報を提供します。
ディープ ラーニングの多くは複雑ですが、モデルのパフォーマンスの最適化は、大規模な場合でも複雑である必要はありません。基本的な原則は、1 つのアクセラレータから数万個のアクセラレータまで、あらゆる場所に適用されます。これらの原則を理解することで、多くの有用なことを行うことができます。
- モデルの各部分が理論上の最適値にどれだけ近いかを推定します。
- さまざまなスケールでさまざまな並列処理スキームについて十分な情報を得たうえで選択する(複数のデバイスに計算を分割する方法)。
- 大規模な Transformer モデルのトレーニングと実行に必要な費用と時間を推定します。
- TPU アーキテクチャを活用するアルゴリズムを設計します。
- アルゴリズムのパフォーマンスを制限する要因を明確に理解したうえで、モデル アーキテクチャを設計します。
前提条件
LLM と Transformer アーキテクチャの基本的な知識は必要ですが、大規模な運用方法に関する知識は必ずしも必要ありません。LLM トレーニングの基本を理解している必要があります。また、JAX の基本をある程度理解していることが望ましいです。Transformer アーキテクチャの背景を理解するうえで役立つ資料は次のとおりです。
- The Illustrated Transformer: Transformer アーキテクチャに関するブログ投稿
- Attention is all you need: Transformer の元論文
これらの前提条件を理解したら、特定の TPU プラットフォームで Transformer モデルに最適な並列処理スキームを推定できるようになります。トレーニングと推論にかかる時間も推定できます。
モデルのスケーリングの重要性
現在の LLM とほとんどの小規模モデルはハードウェアの限界に近い状態で実行されるため、モデルを開発するには、大規模な効率性を考慮する必要があります。ベンチマークで 20% の改善があっても、ルーフライン効率が 20% 低下するなら意味がありません。有望なモデル アーキテクチャは、大規模な環境で効率的に実行できないか、効率的に実行するための最適化が不足しているため、日常的に失敗します。
モデル スケーリングの目標は、トレーニングまたは推論に使用するチップの数を増やしながら、スループットを比例的に線形に増やすことです。これは、強いスケーリングと呼ばれます。通常、チップを追加(並列処理)すると計算時間は短縮されますが、チップ間の通信が増加するというデメリットもあります。通信に計算よりも時間がかかると、モデルは通信バウンドになり、スケーリングがうまくいきません。ハードウェアを十分に理解して、ボトルネックが発生する場所を予測することで、ボトルネックを回避するようにモデルを設計または再構成できます。
以降のセクションでは、TPU ハードウェアをスケーリングする方法と、Transformer アーキテクチャの進化について説明します。この情報は、新しいアーキテクチャを設計する研究者と、現世代の LLM を高速で実行するエンジニアの両方にとって有用です。
パート 1: コンセプト
このパートでは、ルーフライン分析と、モデルのスケーリング能力を制限する要因(通信、計算、メモリ)について説明します。次に、TPU が個々のチップとして、また、帯域幅とレイテンシが制限されたチップ間リンクを備えた相互接続システムとしてどのように機能するかを説明します。
- ルーフライン分析の概要: このセクションでは、コンピューティング、通信、メモリの上限に基づいてアルゴリズムの実行速度を概算する方法について説明します。
- TPU アーキテクチャのオペレーション: このセクションでは、TPU のアーキテクチャ、TPU のさまざまなハードウェア モジュールの動作、モデルのトレーニングとサービングへの影響について説明します。
- マルチ TPU 並列処理のモデル シャーディング: このセクションでは、シャーディングされた行列乗算について説明し、モデル シャーディングとマルチ TPU 並列処理について詳しく説明します。
パート 2: Transformer のスケーリング
Transformer アーキテクチャのすべての部分(各行列の正確なサイズ、正規化が行われる場所、各部分のパラメータ数と FLOP 数)を理解することが重要です。このパートでは、この Transformer の数学を詳しく説明し、トレーニングと推論の両方でパラメータと FLOP をカウントする方法を示します。これにより、モデルが使用するメモリ量、コンピューティングまたは通信に費やす時間、フィードフォワード ブロックに対して注意が重要になるタイミングを確認できます。
最後に、このパートでは、特定のサイズのモデルと特定の数のチップが与えられた場合に、強いスケーリング条件を維持するためにモデルを並列化する方法という基本的な質問に対する答えを得るのに役立ちます。この質問に答えるため、このパートでは、モデルを複数のチップに分割するために使用される 4 つの主要な並列処理手法(データ、テンソル、パイプライン、エキスパート)について説明します。また、再実体化、ZeRO を利用したモデル シャーディング、ホスト オフロード、グラデーションの累積など、メモリ要件を削減する他の手法についても説明します。
- Transformer の数学演算の概要: このセクションでは、Transformer がフォワード パスとバックワード パスで使用する FLOP の数、パラメータの数を計算する計算、KV キャッシュのサイズに関する質問に答えるための数学演算について説明します。
- トレーニング用の Transformer 並列化: このセクションでは、FSDP、Megatron シャーディング、パイプライン並列処理を調整してトレーニング効率を最大化するプロセスについて詳しく説明します。特定のモデルサイズとバッチサイズで、固定数のチップに最適な分布を決定して、ピーク スループットを実現する方法について説明します。
- TPU での Llama 3 のトレーニング: このサブセクションでは、TPU で Llama 3 をトレーニングする方法、トレーニングにかかる時間、費用について説明します。
- 推論用の Transformer スケーリング: モデルのトレーニングが完了したら、モデルをサービングする必要があります。推論により、新しい考慮事項(レイテンシ)が追加され、メモリの状況が変化します。このセクションでは、分離型サービングの仕組みと KV キャッシュの考え方について説明します。
- TPU で Llama 3 をサービングする: このサブセクションでは、TPU で Llama 3 をサービングする方法、費用、レイテンシとスループットのトレードオフについて説明します。
パート 3: 実用的な実装
このパートでは、JAX を使用してスケーリングのコンセプトを実装する方法と、問題が発生した場合にコードのプロファイリングとデバッグを行う方法について説明します。
- TPU プログラムのプロファイリング: 実際の LLM は複雑で、開発、最適化、デバッグが困難です。このセクションでは、JAX + XLA スタックと、JAX/TensorBoard プロファイラを使用して実際の問題をデバッグして修正する方法について説明します。
- JAX での TPU のプログラミング: このセクションでは、JAX API を使用して計算を並列化する方法について説明します。