TPU7x(Ironwood)のパフォーマンスの最適化
このガイドでは、TPU7x(Ironwood)のパフォーマンスを最適化するために、マルチティア メモリ システム間のデータ移動を効率的に管理するいくつかの方法について説明します。これには、低精度トレーニング、シャーディング、通信の最適化、アクティベーションの再実体化、スコープ付き仮想メモリ チューニング、カスタム アクセラレータ カーネルなどの手法が含まれます。
TPU7x でパフォーマンスを最適化するには、まず Ironwood アーキテクチャ、特にメモリ階層と相互接続トポロジを理解する必要があります。詳細については、TPU7x(Ironwood)をご覧ください。
FP8 を使用した低精度トレーニング
FP8(8 ビット浮動小数点)は、主にモデルのトレーニングと推論を高速化するために使用される効率的な数値データ形式です。標準の 16 ビット形式(FP16 または BF16)や 32 ビット(FP32)ではなく、8 ビットを使用して数値を表すことで、TPU はデータを大幅に高速に処理し、メモリ使用量を削減できます。
TPU7x は、FP8 データ型の組み込みハードウェア アクセラレーションをサポートしており、チップあたり 4, 614 TFLOPS の理論上のピーク パフォーマンスを提供します。この機能により、エンドツーエンドのトレーニング時間を大幅に短縮できます。互換性のあるオペレーション(特に AI ワークロードで一般的な高密度行列乗算)では、FP8 を使用すると、標準の BF16 トレーニングよりも 1.3 倍のパフォーマンス向上が得られます。BF16 と比較して、FP8 はピーク FLOP が 2 倍になり、重みとアクティベーションのメモリ使用量が半分になります。FP8 は、コンピューティング バウンドのワークロードと、メモリ容量または帯域幅によって制約されるシナリオの両方で、主要なチューニング レバーとなります。
FP8 を使用すると、次のようなパフォーマンス上のメリットがあります。
- 高帯域幅メモリ(HBM)の負荷を軽減: メモリ フットプリントが小さくなるため、推論中に大きなモデルや大きな KV キャッシュを持つモデルを 192 GB の HBM に完全に収めることができます。これにより、低速のホストメモリへの高コストのオフロードを回避できます。
- 有効なバッチサイズの増加: FP8 は、アクティベーションに必要なメモリを削減することで、より大きなバッチサイズを使用できるようにします。これにより、データ並列処理が改善され、スループットの向上とコンピューティング ユニットの利用率の向上につながります。
- メモリ帯域幅の要件の低減: 各オペレーションで半分の量のデータを移動することで、HBM から MXU へのデータパスの需要が減少します。データ移動が一般的なボトルネックとなるシステムでは、これにより MXU のワークロードを飽和状態に保つことができます。
パフォーマンスの低下をゼロまたは最小限に抑えて FP8 を使用するには、量子化手法を慎重に選択する必要があります。FP8 トレーニングで考慮すべきベスト プラクティスをいくつかご紹介します。
- スケーリングの粒度: ベースラインとしてテンソルごとのスケーリングから始めます。品質やパフォーマンスに関する問題がある場合は、軸ごとのスケーリングに切り替えます。サブチャネルのスケーリングは不要になる可能性があります。
- スケーリング モード: ランタイムでスケーリング ファクタを計算する動的スケーリングは、品質を維持するための適切なデフォルトです。静的スケーリングは、計算を排除することでパフォーマンスを大幅に向上させることができますが、正しいスケーリング係数を決定するには慎重なプロファイリングが必要です。また、特にモデル構成が変更される場合など、すべてのユースケースに適しているとは限りません。逆に、一部の堅牢なモデルと構成では、重みまたはアクティベーションのスケールを FP8 の上限に固定できるため、精度を維持しながら量子化オーバーヘッドを削減し、パフォーマンスを向上させることができます。
- FP8 形式(E4M3 と E5M2): 一般的で効果的なアプローチは、FP8 形式を組み合わせて使用することです。たとえば、フォワード パスで重みとアクティベーションに E4M3 を使用して E4M3 の高精度を活用し、バックワード パスで勾配に E5M2 を使用して勾配の広いダイナミック レンジに対応します。
- 丸め: 勾配に確率的丸めではなく「最も近い偶数への丸め」(RNE)を使用すると、品質を維持しながら、パフォーマンスと再現性を向上させることができます。
- MaxText で FP8 を有効にする: MaxText は、QWIX 量子化ライブラリを介して FP8 トレーニングをサポートしています。量子化を有効にするには、構成で
use_qwix_quantization=trueフラグを設定します。
シャーディングと並列処理
シャーディングとは、大きなモデルまたはそのトレーニング データを小さな部分に分割し、複数の TPU チップまたはコアに分散するプロセスです。適切なシャーディング戦略を選択することは、TPU7x で高いパフォーマンスを実現するために重要です。
並列処理の度合いを純粋に最大化する単純なアプローチでは、通信バウンドになることでパフォーマンスが低下することがよくあります。通常は、メモリ制約を満たす最もシンプルなシャーディング戦略を選択するのが最善のアプローチです。これにより、通信オーバーヘッドが最小限に抑えられ、コンピューティング ユニットを効率的に使用できます。
シャーディング戦略を選択する前に、パフォーマンス チューニングの最初のステップとして算術強度分析を行う必要があります。この分析では、特定の計算がコンピューティング、メモリ帯域幅、相互接続帯域幅のいずれによって制限されているかが判断されます。これは、移動する必要があるデータのバイト数に対する浮動小数点演算の比率として計算されます。
算術演算の強度が高い場合は、コンピューティング バウンドのワークロードであることを示します。算術強度が低い場合は、メモリまたは通信バウンドのワークロードを示します。この場合、パフォーマンスは HBM から ICI ネットワーク経由でデータを移動できる速度によって制限されます。この分析により、理想的なバッチサイズとシャーディング戦略を把握できます。たとえば、通信バウンドのワークロードでは、高次テンソル並列処理など、通信をさらに増やすシャーディング戦略はメリットがありません。
シャーディング戦略の決定フレームワーク
MaxText には、さまざまなシャーディング戦略が用意されています。最適な選択は、モデル アーキテクチャ、シーケンス長、計算負荷と通信オーバーヘッドのバランスを取る必要性によって異なります。
- 完全にシャーディングされたデータ並列処理(FSDP): データ並列処理の推奨されるデフォルトの戦略です。FSDP は、モデルの重み、勾配、オプティマイザーの状態をデータ並列デバイス間でシャーディングします。計算中、各デバイスは All-Gather オペレーションを実行して、ローカル マイクロバッチに必要な完全な重みを取得します。FSDP は、デバイスごとのバッチサイズがこの All-Gather 通信のレイテンシを隠すのに十分な大きさであれば、非常に効果的です。Mixture-of-Experts(MoE)モデルの場合、算術強度の計算でスパース性を考慮する必要があります。
- テンソル並列処理(TP): TP は、デバイス間で個々のテンソルをシャーディングします。通常、テンソルは多層パーセプトロン(MLP)と注意ブロックの重み行列です。ハードウェアの算術演算密度が高い(11.5k)ため、ICI よりも TP を実現するには、モデルの次元に非常に高い要件が課せられます。TP を使用しようとすると、システムが通信バウンドになる可能性があります。
- エキスパート並列処理(EP): MoE モデルのトレーニングに必要な標準的な戦略です。EP は「エキスパート」レイヤをデバイスのセット全体にシャーディングし、オールツーオール通信コレクティブを使用してトークンを指定されたエキスパート デバイスにルーティングします。モデルの MLP ディメンションがルーフラインに近づくほど大きい場合、EP は効率的になります。
- コンテキスト並列処理(CP): CP は、非常に長いシーケンス長のモデルをトレーニングするために不可欠な特殊な戦略です。主な機能は、アクティベーションのメモリ使用量を管理することです。アクティベーションのメモリ使用量はシーケンス長とともに 2 次関数的に増加し、HBM 容量を超える可能性があります。CP はアクティベーション テンソルのシーケンス ディメンションをシャーディングするため、デバイスごとのバッチサイズを分数で使用できます。CP は FSDP よりも多くの通信を導入するため、一般的なルールとして、メモリ制約を満たし、バッチ軸シャードが整数であることを保証するために必要な最小限の CP を使用します。
次の表に、一般的なワークロード タイプと最適なシャーディング戦略のマッピングを示します。
| ワークロード タイプ | 推奨されるプライマリ シャーディング | セカンダリ シャーディング | 主なボトルネック | 根拠 |
|---|---|---|---|---|
| 高密度モデル - 短いシーケンス | FSDP | なし | 再実体化、FF Matmuls | FSDP は最適なバランスを提供します。短いシーケンスの場合、アクティベーション メモリはそれほど懸念する必要がないかもしれません。鍵は、FSDP の重み All-Gather を隠すのに十分な大きさのグローバル バッチです。バッチサイズが増加すると、アクティベーション サイズも増加します。この構成でメモリ不足が発生しないようにするには、適切な再実体化ポリシーが必要です。 |
| 高密度モデル - 長いシーケンス | FSDP | CP | Flash Attention、Activation Memory | アクティベーション メモリが主な制約になります。CP は、デバイスごとのバッチサイズの端数を有効にして、メモリ不足(OOM)の問題を回避するために必要です。フラッシュ アテンションは、コンピューティングと無駄な時間の主な原因です。 |
| MoE モデル - 短いシーケンス | FSDP + EP | なし | All-to-All(エキスパート ルーティング)、再実体化 | MoE モデルでは、EP がエキスパートをシャーディングする必要があります。トークン ルーティングの All-to-All 通信は、重複させる必要がある大きなボトルネックです。再マテリアル化も無駄の大きな原因です。 |
| MoE モデル - 非常に大規模 | FSDP + EP + PP | モデル並列処理(MP) | 前述のボトルネックとパイプライン バブル | 単一の Pod のメモリを超えるモデルの場合、PP は Pod 間でレイヤをシャーディングするために必要です。これにより、DCN 通信とパイプライン バブルのオーバーヘッドが発生します。これは非常に複雑な構成であり、慎重なチューニングが必要です。 |
通信の最適化
TPU7x で通信とコンピューティングをオーバーラップさせる主なメカニズムは、SparseCore Collective Offloading と呼ばれます。Ironwood アーキテクチャには、ICI ファブリック上のデータ移動を管理できる独立した制御スレッドとして機能する専用の SparseCore ユニットが含まれています。これにより、TensorCore で実行されるメインの計算と並行して、集団通信演算(All-Gather や Reduce-Scatter など)を実行できます。これは、TPU7x で非同期コレクティブを使用する場合に推奨される方法です。推奨フラグを使用して、最も一般的なコレクティブのオフロードを有効にします。
アクティベーションの再実体化
アクティベーションの再実体化(グラデーション チェックポイントとも呼ばれます)は、モデルの HBM フットプリントを削減するための基本的な手法です。フォワードパスからの中間アクティベーションをすべて HBM に保存してバックワードパスで使用するのではなく、いくつかのキー アクティベーション(チェックポイント)のみを保存し、バックワードパスで他のアクティベーションをオンデマンドで再計算します。これにより、計算量が増加する(標準の Transformer ブロックの場合、約 25 ~ 30% の追加の FLOP)代わりに、メモリを大幅に節約できます。
再実体化をどの程度積極的に適用するかという決定は、プライマリ ボトルネックに完全に依存する重要なチューニング パラメータです。プライマリ ボトルネックはシーケンス長によって異なることがよくあります。
長いシーケンス ワークロードの場合(128k など): この場合、アクティベーション テンソルのサイズが HBM の主な消費者になります。通常、ワークロードはメモリバウンドです。したがって、積極的な再実体化ポリシーを適用することは非常に有益です。メモリを節約することで、メモリ不足エラーが発生することなくトレーニングを進めることができ、バッチサイズを大きくすることもできます。再計算の計算オーバーヘッドは、価値のあるトレードオフです。
短シーケンス ワークロードの場合(8k など): このような場合、アクティベーション メモリはあまり問題になりません。ワークロードがコンピューティング バウンドになる可能性が高くなります。再実体化の計算オーバーヘッドは、非効率性の最大の原因となる可能性があります。
MaxText での再実体化ポリシーのチューニング
MaxText では、remat_policy フラグを使用して構成された一連のプリセット ポリシーとカスタム ポリシーを通じて、再実体化をきめ細かく制御できます。
プリセット ポリシー
MaxText には、次の組み込みポリシーが用意されています。
full: 最も積極的なポリシー。ほぼすべてのものを再実体化します。これにより、HBM の使用量は最小限に抑えられますが、再計算のオーバーヘッドは最大になります。メモリが非常に制約された長いシーケンスのシナリオに最適です。minimal: 最もアグレッシブでないポリシーで、ほとんどのアクティベーションを保存します。これにより、HBM の使用率が最大化され、再計算が最小限に抑えられます。メモリが問題にならない短いシーケンスのコンピューティング バウンド ワークロードに最適です。- 中間ポリシー:
save_dot_with_context_except_mlp、save_qkv_proj、save_out_projなどのオプションは、高コストの内積演算の出力を選択的にチェックポイントし、より安価な要素ごとの演算を再実体化することで、さまざまなトレードオフを提供します。
カスタム ポリシー
より詳細に制御するには、remat_policy を custom に設定します。これにより、モデルのデコード モジュール内の個々のレイヤの動作を指定できます。各レイヤには、次の 3 つの動作のいずれかを割り当てることができます。
device: アクティベーションは TPU デバイスの HBM に保存されます。remat: アクティベーションは破棄され、バックワードパスで再実体化されます。offload: 活性化が HBM から CPU ホストのメモリに移動し、PCIe 転送レイテンシと引き換えに HBM が解放されます。
スコープ付き VMEM チューニング
フラッシュ アテンションなどのカーネル パフォーマンスは、カーネルで選択されたタイルサイズに依存します。このサイズは、使用可能なベクトル メモリ(VMEM)によって制限されます。TPU7x チップには 64 MB の VMEM があり、現在のスコープ(スコープ付き VMEM)と将来の重みプリフェッチの間で分割できます。スコープ付き VMEM を増やすと、カーネルのタイルサイズを大きくできるため、メモリの停止を減らし、カーネルのパフォーマンスを向上させることができます。xla_tpu_scoped_vmem_limit_kib(LIBTPU_INIT_ARGS 内)を設定することで、スコープ付き VMEM サイズを変更できます。これは、カーネル パフォーマンスとエンドツーエンドのパフォーマンスの上限を調べるために使用できます。スコープ付き VMEM サイズを最適化すると、カスタム Pallas カーネルのパフォーマンスに間接的に影響する可能性があります。これは、スコープ付き VMEM を増やすと、カーネル内タイルサイズのハイパーパラメータ検索空間が広がるためです。
Tokamax カーネル
Tokamax は、多くの高度に最適化された TPU カーネルを含む高パフォーマンスの JAX カーネル ライブラリであり、一般的なハードウェア固有のボトルネックをいくつか解消します。
- Splash アテンション: Splash アテンションは、標準アテンションの HBM ボトルネックを解消し、TPU で最も効率的なアテンション実装を使用するためのプライマリ アテンション実装として使用されます。
- Megablox グループ化行列乗算(GMM): MoE ワークロードの場合、Megablox は、不均一なアクティベーション表現で計算することで、グループ化された行列乗算を効率的に処理します。不均一なディメンションを効率的にマッピングし、LHS の不均一な行グループと対応するエキスパート行列の間の行列乗算を計算します。バッチを固定サイズにパディングする必要はありません。
tune-jaxを使用した実証的チューニング:tune-jaxライブラリには、最適なブロックサイズを実証的に検索するユーティリティがあります。デフォルトのカーネルサイズは最適でないことが多く、チューニングにより、ハードウェアの使用率を最大化するハードウェア フレンドリーな VMEM タイルサイズを選択できます。- 最大ロジット推定値:
max_logit_constの値を設定することで、Tokamax Splash 注意カーネルをさらに最適化できます。設定すると、注意の softmax オペレーション(softmax(Q * KT))中の最大ロジットの削減計算が置き換えられ、計算と同期のオーバーヘッドが削減されます。MaxText では、use_max_logits_estimate構成によって実装されます。この構成はNone(無効)または浮動小数点値に設定できます。特定のモデルのロジット範囲が推定値と互換性があることを確認して、数値オーバーフローを防ぎます。この値が設定されている場合は、収束テストをおすすめします。