GKE の TPU で JetStream と PyTorch を使用して LLM をサービングする

このガイドでは、Google Kubernetes Engine(GKE)で Tensor Processing Unit(TPU)を使用し、JetStreamPyTorch で大規模言語モデル(LLM)をサービングする方法について説明します。このガイドでは、モデルの重みを Cloud Storage にダウンロードし、JetStream を実行するコンテナを使用して GKE Autopilot または Standard クラスタにデプロイします。

モデルを JetStream にデプロイするときに Kubernetes の機能によって実現されるスケーラビリティ、復元力、費用対効果が必要な場合は、このガイドが適しています。

このガイドは、PyTorch を使用している生成 AI をご利用のお客様、GKE の新規または既存のユーザー、ML エンジニア、MLOps(DevOps)エンジニア、LLM のサービングに Kubernetes コンテナのオーケストレーション機能を使用することに関心をお持ちのプラットフォーム管理者を対象としています。

背景

GKE で TPU を使用して JetStream で LLM をサービングすることで、マネージド Kubernetes のメリット(費用効率、スケーラビリティ、高可用性など)をすべて活用した、プロダクション レディな堅牢なサービング ソリューションを構築できます。このセクションでは、このチュートリアルで使用されている重要なテクノロジーについて説明します。

TPU について

TPU は、Google が独自に開発した特定用途向け集積回路(ASIC)であり、TensorFlowPyTorchJAX などのフレームワークを使用して構築された ML モデルと AI モデルを高速化するために使用されます。

GKE で TPU を使用する前に、次の学習プログラムを完了することをおすすめします。

  1. Cloud TPU システム アーキテクチャで、現在の TPU バージョンの可用性について学習する。
  2. GKE の TPU についてを確認する。

このチュートリアルでは、さまざまな LLM モデルのサービングについて説明します。GKE は、低レイテンシでプロンプトをサービングするモデルの要件に基づいて構成された TPU トポロジを使用して、単一ホストの TPUv5e ノードにモデルをデプロイします。

JetStream について

JetStream は、Google が開発したオープンソースの推論サービング フレームワークです。JetStream を使用すると、TPU と GPU で高性能、高スループット、メモリ最適化された推論が可能になります。JetStream では、連続バッチ処理、KV キャッシュの最適化、量子化手法などの高度なパフォーマンス最適化により、LLM を簡単にデプロイできます。JetStream では、PyTorch / XLA と JAX TPU のサービングにより、最適なパフォーマンスを実現できます。

連続的なバッチ処理

連続的バッチ処理は、受信した推論リクエストを動的にバッチにグループ化し、レイテンシを短縮してスループットを向上させる手法です。

KV キャッシュの量子化

KV キャッシュの量子化では、アテンション機構で使用される Key-Value キャッシュを圧縮して、メモリ要件を削減します。

Int8 重み量子化

Int8 重み量子化では、モデル重みの精度を 32 ビットの浮動小数点数から 8 ビットの整数にすることで、計算速度を向上させ、メモリ使用量を削減しています。

これらの最適化の詳細については、JetStream PyTorchJetStream MaxText のプロジェクト リポジトリをご覧ください。

PyTorch について

PyTorch は、Meta によって開発されたオープンソースの ML フレームワークで、現在は Linux Foundation 傘下にあります。PyTorch は、テンソル計算やディープ ニューラル ネットワークなどの高度な機能を提供します。