Pathways を使用した JAX の分散型のため、通信オーバーヘッドにより一部のオペレーションがうまくスケーリングされないことがあります。Pathways は、非同期ディスパッチなどの機能でこれらのオーバーヘッドを最小限に抑えますが、JAX ワークロードを Pathways に移植したり、Pathways ワークロードで JAX を多数のアクセラレータにスケーリングしたりする際には、注意すべき点があります。
始める前に
インストールに必要なもの:
- インストールされている Kubernetes ツール
- gcloud CLI がインストールされている
- TPU API を有効にした
- Google Kubernetes Engine API を有効にした
プロセス インデックス
Pathways を使用する JAX は、Pathways クラスタ内のすべてのデバイスをローカルとして扱います。これにより、デバイス管理が簡素化され、JAX は利用可能なすべてのリソースを利用できるようになります。実際には、次のようになります。
jax.process_index()はすべてのデバイスで常に 0 です。jax.devices()とjax.local_devices()は、ジョブ全体のすべての TPU デバイスを返します。
ハードウェアのタイプとコロケーション
パフォーマンスを最大限に高めるには、すべての Pathways コンポーネントとユーザー ジョブを同じ Google Cloud クラウドゾーンに配置します。IFRT プロキシやリソース マネージャーなどの大きな CPU を使用します。64 個の vCPU と 256 GB のメモリを備えた専用の n2-standard-64 を使用することをおすすめします。
PathwaysUtils
Pathways-utils は、Python ベースの GitHub リポジトリです。このリポジトリには、Cloud アーキテクチャ上の Pathways で JAX ワークロードのデプロイと実行を効率化するための重要なユーティリティとツールが用意されています。このパッケージは、クラウド環境に必要な適応処理を行うため、JAX デベロッパーはプラットフォーム固有の構成を最小限に抑えながら、コアの ML ワークフローに集中できます。具体的には、次の機能を提供します。
- 「プロキシ」JAX バックエンド: このカスタム バックエンドを使用すると、
JAX_PLATFORMS=proxy環境変数を設定することで、JAX アプリケーションで Pathways インフラストラクチャを使用できます。 - 統合プロファイリング ユーティリティ: アプリケーションのパフォーマンスを把握できるプロファイリング機能。
jax.profiler.start_traceやjax.profiler.start_serverなどの標準の JAX プロファイリング API を使用すると、JAX コードだけでなく、基盤となる Pathways コンポーネントもプロファイリングできるため、クラウド環境内の実行を包括的に把握できます。 - Orbax を使用した分散チェックポインティング: Pathways 環境内で Orbax ライブラリを使用する際に、分散チェックポイントを使用してチェックポイントを復元できるカスタム Orbax チェックポイント ハンドラ。この統合は、
pathwaysutilsをインポートする限り、既存の Orbax チェックポイント コードを変更することなく機能することを目的としています。 - Elastic Training Primitives: Pathways を使用して堅牢でスケーラブルなトレーニング ワークフローを構築するために使用できる、基本的なエラスティック トレーニング プリミティブを提供します。これらのプリミティブを使用すると、トレーニング ジョブが利用可能なリソースの変更に動的に適応し、クラウド環境の効率と復元力が向上します。
チェックポイント処理
Orbax は、Cloud Storage を使用した分散チェックポインティングと復元のために Pathways で徹底的にテストされています。train.py で import pathwaysutils; pathwaysutils.initialize() を呼び出すと、IFRT プロキシを介してチェックポイント オペレーションを効率的に処理するカスタム ArrayHandler が登録され、アクセラレータ上の Pathways ワーカーがデータを直接保存して復元できるようになります。
コロケーションされた Python
Colocated Python は、ユーザーが指定した Python コードを TPU または GPU ホストで直接実行できるオープンソースの JAX API です。これは、マルチコントローラ JAX でより簡単に行えます。これにより、データ読み込みやチェックポイントなどのコンピューティング負荷の高いタスクで、クライアントと TPU マシン間のデータ転送を回避できます。同じ場所に配置された Python JAX API を実行するように Pathways クラスタを構成するには、同じ場所に配置された Python の README の手順に沿って操作します。これらの手順では、Pathways ワーカーとともに同じ場所に配置された Python サイドカーを起動する方法について説明します。
データ読み込み
トレーニング中は、データセットからバッチを繰り返し読み込み、モデルにフィードします。アクセラレータの使用率の低下を回避するには、バッチをホスト間でシャーディングする効率的な非同期データローダーが必要です。Pathways でトレーニングを実行する場合、データローダは CPU VM で実行され(マルチコントローラ設定で使用される TPU VM とは異なります)、データを TPU VM にディスパッチします。これにより、データの読み取りのレイテンシが大きくなりますが、CPU ホストで X 個のバッチを事前に読み取り、読み取ったデータを TPU に非同期でディスパッチすることで、部分的に軽減されます。このソリューションは、小規模から中規模で実行する場合に十分です。
スケーリング時に最適なパフォーマンスを実現するには、コロケーションされた Python を使用して、アクセラレータでデータ パイプラインを直接実行し、入力データ パイプラインをコロケーションすることを強くおすすめします。これにより、CPU のボトルネックが解消され、TPU の高速インターコネクトがデータ転送に活用されます。
TFDS ベースの入力パイプラインを移行するリファレンス実装は、multihost_dataloading.py の RemoteIterator 実装にあります。この実装は、コロケーションされた Python JAX API を使用して、マルチ コントローラ JAX と Pathways の両方で分散方式で動作します。
Jax のバージョニング
Pathways のリリースは、互換性と安定性を確保するために JAX バージョンと密接に連携しています。潜在的な問題を回避するため、Pathways アーティファクトと JAX バージョンが一致していることを確認してください。Pathways の各リリースでは、jax-<version> 形式のタグを使用して、互換性のある JAX バージョンを明確に指定しています。
コンパイル キャッシュ
Pathways 永続コンパイル キャッシュは、Pathways サーバーがコンパイル済みの XLA 実行可能ファイルを Cloud Storage などの永続的な場所に保存して、冗長なコンパイルを回避できるようにする機能です。この機能はデフォルトで有効になっています。キャッシュの場所は、--gcs_scratch_location フラグとしてリソース マネージャーと Pathways ワーカー コンテナに渡されます。関連するストレージ費用を最小限に抑えるため、キャッシュはライフサイクル ポリシーを Cloud Storage のロケーションに適用します。Cloud Storage バケットあたりのポリシー数は 50 個までです。そのため、すべてのワークロードで共通の Cloud Storage ロケーションを使用することをおすすめします。
このキャッシュは、Pathways ワークロードで pathwaysutils.initialize() によって無効にされる JAX コンパイル キャッシュに似ています。
プロファイリング
JAX プロファイラを使用して、JAX プログラムのトレースを生成できます。Pathways でサポートされている一般的な方法は 2 つあります。
どちらの場合も、プロファイルは Cloud Storage バケットに書き込まれます。Cloud Storage バケットには、複数のトレース ファイルが作成されます。これらのファイルは、異なるタイムスタンプ フォルダに保存される可能性があります。次に例を示します。
- トレースを呼び出したメインの Python プロセス(通常はノートブック VM):
<jax-client-vm-name>.xplane.pb - Pathways IFRT プロキシ:
client.<var>PATHWAYS_HEAD_NODE_NAME</var>.xplane.pb - Pathways Resource Manager:
server.*<var>PATHWAYS_HEAD_NODE_NAME</var>.xplane.pb - Pathways ワーカー:
server.*<tpu-node-name>.xplane.pb
これらのトレース ファイルは、次のコマンドを実行して TensorBoard で分析できます。TensorBoard とそのすべてのプロファイリング ツールについて詳しくは、プロファイラを使用して TensorFlow のパフォーマンスを最適化するをご覧ください。
# verify trace files are present gsutil ls -l -r gs://BUCKET/PREFIX # View on tensorboard tensorboard --logdir=gs://BUCKET/PREFIX
次のように置き換えます。
BUCKET: トレース ファイルを保存する Cloud Storage バケットPREFIX: トレース ファイルを保存する Cloud Storage バケット内のパス
プログラムによるプロファイルのキャプチャ
コード内からプロファイルをキャプチャします。プロファイルは gs://<bucket>/<prefix> 内のタイムスタンプ ディレクトリに保存されます。
import jax import pathwaysutils pathwaysutils.initialize() jax.profiler.start_trace("gs://BUCKET/PREFIX") # Run the operations to be profiled key = jax.random.key(0) x = jax.random.normal(key, (5000, 5000)) y = x @ x y.block_until_ready() jax.profiler.stop_trace()
手動プロファイル キャプチャ
プロファイル情報を手動でキャプチャするには、Python コードからプロファイラ サーバーを起動する必要があります。
import jax
import pathwaysutils
pathwaysutils.initialize()
jax.profiler.start_server(jax_profiler_port)
# Your JAX code
jax.profiler.stop_server() # this is functinoally a no-op
プロファイラ サーバーの実行中に、プロファイルをキャプチャして、データをターゲットの Cloud Storage ロケーションにエクスポートできます。
export DURATION_IN_SECS=6
export OUTPUT_DIR=gs://<var>BUCKET</var>/<var>PREFIX</var>
curl -d "{\"duration_ms\":\"${DURATION_IN_SECS} * 1000 }}\", \"repository_path\":\"${OUTPUT_DIR}\"}" -H "Content-Type: application/json" -X POST http://localhost:<jax_profiler_port
プログラムのトレース内で、Compile や Execute などの IFRT プロキシ クライアント メソッドのタイミング情報を確認できます。コンパイルと実行中の IFRT Proxy gRPC サーバーとのやり取りの詳細を示すこれらのイベントは、GrpcClientSessionUserFuturesWorkQueue という名前のスレッドに表示されます。トレースでこのスレッドを調べると、これらのオペレーションのパフォーマンスに関する分析情報を得ることができます。
XLA フラグ
Pathways を使用する場合は、pathways-proxy コンテナで XLA フラグを設定する必要があります。これを行うには、XPK または PathwaysJob API を使用します。
XPK を使用する場合は、次のように XLA フラグを設定します。
--custom-pathways-proxy-server-args="--xla_flag_1=value1 --xla_flag_2=value2"
PathwaysJob API を使用する場合は、次のように XLA フラグを設定します。
apiVersion: pathways-job.pathways.domain/v1 kind: PathwaysJob metadata: name: pathways-USER spec: customComponents: - componentType: proxy_server customFlags: - --xla_flag_1=value1 - --xla_flag_2=value2
次のように置き換えます。
USER: Google Cloud ユーザー名value[n]: 設定する XLA フラグ
HLO ダンプ
XLA コンパイラに渡される High Level Optimizer(HLO)入力を詳しく調べるには、次のように、指定した Cloud Storage の場所に HLO をダンプするように Pathways を構成します。
apiVersion: pathways-job.pathways.domain/v1 kind: PathwaysJob metadata: name: pathways-USER spec: customComponents: - componentType: proxy_server customEnv: - name: XLA_FLAGS value: "--xla_dump_to=gs://your-gcs-bucket/your-desired-prefix/"
次のステップ
- Pathways を使用して GKE クラスタを作成する
- Pathways を使用したマルチホスト推論
- Pathways を使用したバッチ ワークロード
- Pathways インタラクティブ モード
- Pathways を使用した復元力トレーニング
- トラブルシューティングのパス