JAX ワークロードを Pathways に移植する

Pathways を使用した JAX は分散型であるため、通信オーバーヘッドにより一部のオペレーションが適切にスケーリングされない可能性があります。Pathways は非同期ディスパッチなどの機能でこれらの オーバーヘッドを最小限に抑えますが、JAX ワークロードを Pathways に移植する場合や、Pathways を使用した JAX ワークロードを多数のアクセラレータにスケーリングする場合は、注意すべき点があります。

始める前に

インストールに必要なもの:

プロセス インデックス

Pathways を使用した JAX は、Pathways クラスタ内のすべてのデバイスをローカルとして扱います。 これにより、デバイス管理が簡素化され、JAX は利用可能なすべてのリソースを使用できます。実際には、次のようになります。

  • すべてのデバイスで jax.process_index() が常に 0 になります。
  • jax.devices()jax.local_devices() は、ジョブ全体で TPU デバイスをすべて返します。

ハードウェアのタイプとコロケーション

最適なパフォーマンスを得るには、すべての Pathways コンポーネントとユーザー ジョブを 同じ Google Cloud クラウド ゾーンに配置します。IFRT プロキシやリソース マネージャーなどの大規模な CPU を使用します。64 個の vCPU と 256 GB のメモリを備えた専用の n2-standard-64 以上をおすすめします。

PathwaysUtils

Pathways-utils は Python ベースの GitHub リポジトリで、JAX ワークロードを Cloud 上の Pathways アーキテクチャにデプロイして実行する際に役立つ重要なユーティリティとツールを提供します。このパッケージは、クラウド環境に必要な適応を処理するため、JAX デベロッパーはプラットフォーム固有の構成を最小限に抑えて、コアの機械学習ワークフローに集中できます。具体的には、次の機能があります。

  • 「プロキシ」JAX バックエンド: このカスタム バックエンドを使用すると、JAX_PLATFORMS=proxy 環境変数を設定して、JAX アプリケーションで Pathways インフラストラクチャを使用できます。
  • 統合プロファイリング ユーティリティ: アプリケーションのパフォーマンスを把握できるプロファイリング機能。jax.profiler.start_tracejax.profiler.start_server などの標準の JAX プロファイリング API を使用すると、JAX コードだけでなく、基盤となる Pathways コンポーネントもプロファイリングできるため、クラウド環境での実行を包括的に把握できます。
  • Orbax を使用した分散チェックポイント: カスタムの Orbax チェックポイント ハンドラを使用すると、Pathways 環境で Orbax ライブラリを使用するときに、分散チェックポイントを使用してチェックポイントを復元できます。この統合は、pathwaysutils をインポートする限り、既存の Orbax チェックポイント コードを変更せずに動作することを目的としています。
  • エラスティック トレーニング プリミティブ: Pathways を使用して堅牢でスケーラブルなトレーニング ワークフローを構築するために使用できる、基本的なエラスティック トレーニング プリミティブを提供します。これらのプリミティブを使用すると、トレーニング ジョブは利用可能なリソースの変更に動的に適応し、クラウド環境での効率と復元力を向上させることができます。

チェックポイント処理

Orbax は、Cloud Storage を使用した 分散チェックポイントと復元のために、Pathways で徹底的にテストされています。 で train.py を呼び出すと、 ArrayHandler を介してチェックポイント オペレーションを効率的に処理するカスタム IFRT プロキシが登録され、アクセラレータ上の Pathways ワーカーがデータを直接保存して復元できるようになります。import pathwaysutils; pathwaysutils.initialize()

コロケーションされた Python

コロケーションされた Python は、ユーザー指定の Python コードを TPU または GPU ホストで直接実行できるオープンソースの JAX API です。これは、マルチコントローラ JAX でより簡単です。これにより、データ読み込みやチェックポイント処理などのコンピューティング負荷の高いタスクで、クライアントと TPU マシン間のデータ移転を回避できます。コロケーションされた Python JAX API を実行するように Pathways クラスタを構成するには、 コロケーションされた Python README の手順に沿って操作します。この手順では、Pathways ワーカーとともに コロケーションされた Python サイドカーを起動する方法について説明します。

データ読み込み

トレーニング中は、データセットからバッチを繰り返し読み込み、モデルにフィードします。アクセラレータの使用率の低下を回避するには、バッチをホスト間でシャーディングする効率的な非同期データローダーが必要です。Pathways でトレーニングを実行する場合、データローダーは CPU VM で実行され(マルチコントローラ セットアップで使用される TPU VM とは異なります)、TPU VM にデータをディスパッチします。これにより、データの読み取りのレイテンシが高くなりますが、CPU ホストで事前に X 個のバッチを読み取り、読み取ったデータを TPU に非同期でディスパッチすることで、この問題は部分的に軽減されます。このソリューションは、小規模から中規模で実行する場合に十分です。

大規模な環境で最適なパフォーマンスを得るには、コロケーションされた Python を使用してアクセラレータでデータパイプラインを直接実行し、入力 データ パイプラインをコロケーションすることを強くおすすめします。これにより、CPU のボトルネックが解消され、TPU の高速インターコネクトがデータ転送に活用されます。

TFDS ベースの 入力パイプラインを移行するリファレンス実装は、RemoteIterator 実装にあります。 multihost_dataloading.py。 この実装は、コロケーションされた Python JAX API を使用して、マルチコントローラ JAX と Pathways の両方で分散型で動作します。

Jax のバージョニング

Pathways のリリースは、互換性と安定性を確保するために JAX バージョンと密接に連携しています。潜在的な問題を回避するには、Pathways アーティファクトと JAX バージョンが一致していることを確認します。Pathways の各リリースでは、互換性のある JAX バージョンが 形式のタグを使用して明確に指定されていますjax-<version>

コンパイル キャッシュ

Pathways 永続コンパイル キャッシュは、Pathways サーバーがコンパイル済みの XLA 実行可能ファイルを Cloud Storage などの永続的な場所に保存して、冗長なコンパイルを回避できる機能です。この機能はデフォルトで有効になっています。キャッシュの場所は、--gcs_scratch_location フラグとしてリソース マネージャーと Pathways ワーカー コンテナに渡されます。関連するストレージ費用を最小限に抑えるため、キャッシュは Cloud Storage の保管場所にライフサイクル ポリシーを適用します。Cloud Storage バケットあたりのポリシー数は 50 個に制限されています。そのため、すべてのワークロードで共通の Cloud Storage の場所を使用することをおすすめします。

このキャッシュは、Pathways ワークロードで pathwaysutils.initialize() によって無効になる JAX コンパイル キャッシュ に似ています。

コンパイル キャッシュには、次の Cloud Storage 権限が必要です。

  • storage.buckets.get: バケット メタデータを取得します。
  • storage.buckets.update: Pathways がオブジェクト ライフサイクル ポリシーを設定して、キャッシュの削除に TTL を適用するために不可欠です。
  • storage.objects.list: バケット内の既存のキャッシュ オブジェクトを一覧表示します。
  • storage.objects.create: 新しいコンパイル済み実行可能ファイルをキャッシュに書き込みます。
  • storage.objects.get: バケットからキャッシュに保存された実行可能ファイルを読み取ります。

プロファイリング

JAX プロファイラを使用して、JAX プログラムのトレースを生成できます。Pathways でサポートされている一般的な方法は 2 つあります。

  • プログラマティック
    • JAX コードからプログラムでプロファイルをキャプチャする
  • 手動
    • JAX コードからプロファイラ サーバーを起動した後、オンデマンドでプロファイルをキャプチャする

どちらの場合も、プロファイルは Cloud Storage バケットに書き込まれます。 Cloud Storage バケットには、タイムスタンプ フォルダごとに複数のトレース ファイルが作成されます。次に例を示します。

  • トレースを呼び出したメインの Python プロセス(通常はノートブック VM): <jax-client-vm-name>.xplane.pb
  • Pathways IFRT プロキシ: client.<var>PATHWAYS_HEAD_NODE_NAME</var>.xplane.pb
  • Pathways リソース マネージャー: server.*<var>PATHWAYS_HEAD_NODE_NAME</var>.xplane.pb
  • Pathways ワーカー: server.*<tpu-node-name>.xplane.pb

これらのトレース ファイルは、次のコマンドを実行して TensorBoard で分析できます。TensorBoard とそのすべてのプロファイリング ツールの詳細については、 プロファイラを使用して TensorFlow のパフォーマンスを最適化するをご覧ください。

# verify trace files are present
gsutil ls -l -r gs://BUCKET/PREFIX

# View on tensorboard
tensorboard --logdir=gs://BUCKET/PREFIX

次のように置き換えます。

  • BUCKET : トレース ファイルを保存する Cloud Storage バケット
  • PREFIX: トレース ファイルを保存する Cloud Storage バケット内のパス

プログラムによるプロファイル キャプチャ

コード内からプロファイルをキャプチャします。プロファイルは、タイムスタンプ ディレクトリの gs://<bucket>/<prefix>に保存されます。

import jax
import pathwaysutils

pathwaysutils.initialize()

jax.profiler.start_trace("gs://BUCKET/PREFIX")

# Run the operations to be profiled
key = jax.random.key(0)
x = jax.random.normal(key, (5000, 5000))
y = x @ x
y.block_until_ready()

jax.profiler.stop_trace()

手動でのプロファイル キャプチャ

プロファイル情報を手動でキャプチャするには、Python コードからプロファイラ サーバーを起動する必要があります。

import jax
import pathwaysutils

pathwaysutils.initialize()
jax.profiler.start_server(jax_profiler_port)

# Your JAX code
jax.profiler.stop_server() # this is functinoally a no-op

プロファイラ サーバーの実行中に、プロファイルをキャプチャして、ターゲットの Cloud Storage の保管場所にデータをエクスポートできます。

export DURATION_IN_SECS=6
export OUTPUT_DIR=gs://<var>BUCKET</var>/<var>PREFIX</var>

curl -d "{\"duration_ms\":\"${DURATION_IN_SECS} * 1000 }}\", \"repository_path\":\"${OUTPUT_DIR}\"}" -H "Content-Type: application/json" -X POST http://localhost:<jax_profiler_port

プログラムのトレース内で、CompileExecute などの IFRT プロキシ クライアント メソッドのタイミング情報を確認できます。コンパイルと実行中に IFRT プロキシ gRPC サーバーとのインタラクションの詳細を示すこれらのイベントは、GrpcClientSessionUserFuturesWorkQueue という名前のスレッドに表示されます。トレースでこのスレッドを調べることで、これらのオペレーションのパフォーマンスに関する分析情報を得ることができます。

XLA フラグ

Pathways を使用する場合は、pathways-proxy コンテナに XLA フラグを設定する必要があります。これを行うには、XPK または PathwaysJob API を使用します。

XPK を使用する場合は、次のように XLA フラグを設定します。

--custom-pathways-proxy-server-args="--xla_flag_1=value1 --xla_flag_2=value2"

PathwaysJob API を使用する場合は、次のように XLA フラグを設定します。

apiVersion: pathways-job.pathways.domain/v1
kind: PathwaysJob
metadata:
  name: pathways-USER
spec:
  customComponents:
  - componentType: proxy_server
    customFlags:
    - --xla_flag_1=value1
    - --xla_flag_2=value2

次のように置き換えます。

  • USER : ユーザー Google Cloud 名。
  • value[n]: 設定する XLA フラグ。

HLO ダンプ

XLA コンパイラに渡される High Level Optimizer(HLO)入力を詳細解説するには、次のように Pathways を構成して、指定した Cloud Storage の保管場所に HLO をダンプします。

apiVersion: pathways-job.pathways.domain/v1
kind: PathwaysJob
metadata:
  name: pathways-USER
spec:
  customComponents:
  - componentType: proxy_server
    customEnv:
    - name: XLA_FLAGS
      value: "--xla_dump_to=gs://your-bucket/your-desired-prefix/"

次のステップ