このガイドでは、複数のノードで Tensor Processing Unit(TPU)を使用して、Google Kubernetes Engine(GKE)で Llama 3.1 405B などの最先端の大規模言語モデル(LLM)をサービングする方法について説明します。
このガイドでは、ポータブルなオープンソース テクノロジー(Kubernetes、JetStream、Pathways on Cloud、LeaderWorkerSet(LWS)API)を使用して、GKE のきめ細かい制御、拡張性、復元力、移植性、費用対効果を活用して、GKE に AI / ML ワークロードをデプロイしてサービングする方法について説明します。
背景
大規模言語モデルのサイズが大きくなり、単一のホスト TPU スライスに収まらなくなりました。ML 推論では、Cloud 上の Pathways を使用して、相互接続された複数の TPU ノードにまたがる GKE で大規模なマルチホスト推論を実行できます。このガイドでは、マルチホスト TPU スライスを使用して GKE クラスタをプロビジョニングし、Pathways on Cloud バイナリを使用して MaxText フレームワークで JetStream サーバーを起動し、マルチホスト推論リクエストを行う方法について説明します。
GKE で TPU を使用して JetStream、MaxText、Pathways で LLM をサービングすることで、マネージド Kubernetes のメリット(費用効率、拡張性、高可用性など)をすべて活用した、プロダクション レディな堅牢なサービング ソリューションを構築できます。このセクションでは、このチュートリアルで使用されている重要なテクノロジーについて説明します。
TPU について
TPU は、Google が独自に開発した特定用途向け集積回路(ASIC)であり、TensorFlow、PyTorch、JAX などのフレームワークを使用して構築された ML モデルと AI モデルを高速化するために使用されます。
GKE で TPU を使用する前に、次の学習プログラムを完了することをおすすめします。
- Cloud TPU システム アーキテクチャで、現在の TPU バージョンの可用性について学習する。
- GKE の TPU についてを確認する。
このチュートリアルでは、Llama 3.1-405B モデルのサービングについて説明します。GKE は、低レイテンシでプロンプトをサービングするモデルの要件に基づいて構成された TPU トポロジを使用して、マルチホスト TPU v6e ノードにモデルをデプロイします。
Pathways on Cloud
Pathways は、アクセラレータの大規模なオーケストレーション レイヤです。Pathways は、現在のモデルの最先端のパフォーマンスを維持しながら、新しいシステムと ML 研究のアイデアの探求を可能にするように明示的に設計されています。Pathways を使用すると、単一の JAX クライアント プロセスで 1 つ以上の大規模な TPU スライスにまたがる計算を調整できるため、数百または数千の TPU チップにまたがる ML コンピューティングを効率化できます。
JetStream
JetStream は、Google が開発したオープンソースの推論サービング フレームワークです。JetStream を使用すると、TPU と GPU で高性能、高スループット、メモリ最適化された推論が可能になります。JetStream では、連続バッチ処理、KV キャッシュの最適化、量子化手法などの高度なパフォーマンス最適化により、LLM を簡単にデプロイできます。JetStream では、PyTorch/XLA と JAX TPU のサービングにより、パフォーマンスを最適化できます。
MaxText
MaxText は、Flax、Orbax、Optax などのオープンソースの JAX ライブラリ上に構築された、パフォーマンス、スケーラビリティ、適応性に優れた JAX LLM 実装です。MaxText のデコーダ専用の LLM 実装は Python で記述されています。XLA コンパイラの活用により、カスタム カーネルを構築しなくても高いパフォーマンスを実現できます。
MaxText がサポートする最新のモデルとパラメータ サイズの詳細については、MaxText プロジェクト リポジトリをご覧ください。
Llama 3.1 405B
Llama 3.1 405B は、テキスト生成、翻訳、質問応答など、さまざまな自然言語処理タスク用に設計された Meta の大規模言語モデルです。GKE は、この規模のモデルの分散トレーニングとサービングの実現に必要なインフラストラクチャを提供します。
詳細については、Llama のドキュメントをご覧ください。
アーキテクチャ
このセクションでは、このチュートリアルで使用する GKE アーキテクチャについて説明します。このアーキテクチャには、TPU をプロビジョニングし、モデルをデプロイしてサービングするための JetStream コンポーネントと Pathways コンポーネントをホストする GKE Standard クラスタが含まれています。
次の図は、このアーキテクチャのコンポーネントを示しています。
このアーキテクチャには次のコンポーネントが含まれています。
- GKE Standard リージョン クラスタ。
- JetStream デプロイと Pathways コンポーネントをホストするマルチホスト TPU スライス ノードプール。
- アクセラレータ リソースを管理し、ユーザージョブのアクセラレータの割り当てを調整する
Pathways resource manager。 Pathways resource managerと連携してコンパイルされたプログラムを実行する場所を決定するPathways client。- アクセラレータ マシンで実行されて計算を行い、IFRT プロキシ サーバーを介してワークロードにデータを送り返す
Pathways worker。 - OSS の Interim Framework Runtime(IFRT) API を実装し、ワークロードと Pathways コンポーネント間の通信ブリッジとして機能する
IFRT proxy client。 IFRT proxy clientからリクエストを受け取り、Pathways clientに転送して作業を分散するIFRT proxy server。- 推論リクエストを受け取り、実行プロセスを
Pathways workersに委任する JAX ベースの推論サーバーを提供するJetStream-Pathwaysコンテナ。 - Service コンポーネントは、インバウンド トラフィックをすべての
JetStream HTTPレプリカに分散します。 JetStream HTTPは、JetStream の必須フォーマットのラッパーとしてリクエストを受け取り、JetStream の GRPC クライアントに送信する HTTP サーバーです。