Cloud TPU マルチスライスの概要

Cloud TPU マルチスライスは、標準のデータ並列処理により、単一のスライス内、または複数の Pod 内のスライスで、トレーニング ジョブが複数の TPU スライスを使用できるようにする、フルスタック パフォーマンス スケーリング テクノロジーです。TPU v4 チップでは、トレーニング ジョブは 1 回の実行で 4,096 個を超えるチップを使用できます。4,096 チップ未満を必要とするトレーニング ジョブの場合、単一スライスが最も高パフォーマンスを発揮します。ただし、複数の小さなスライスの方が簡単に利用できるため、マルチスライスを小さなスライスで使用する場合、起動時間が短縮されます。

複数のスライスによるパフォーマンスの線形スケーリング

マルチスライス構成にデプロイすると、各スライス内の TPU チップがチップ間相互接続(ICI)を介して通信します。異なるスライス内の TPU チップは、CPU(ホスト)にデータを転送することで通信します。CPU は、データセンター ネットワーク(DCN)を介してデータを転送します。マルチスライスでのスケーリングの詳細については、マルチスライスで AI トレーニングを最大数万の Cloud TPU チップまでスケーリングする方法をご覧ください。

マルチスライスのデータフロー

スライス間 DCN 通信を実装するためにデベロッパーがコードを記述することはありません。XLA コンパイラが、そのコードを生成し、最大限のパフォーマンスが発揮できるようにコンピューティングと通信をオーバーラップします。

コンセプト

アクセラレータ タイプ
マルチスライスを構成する各 TPU スライスのシェイプ。マルチスライス リクエスト内の各スライスのアクセラレータ タイプは同じです。アクセラレータ タイプは、TPU タイプ(v4 以降)と TensorCore の数で構成されます。たとえば、v5litepod-128 は、TPU v5e と 128 個の TensorCore を示します。
自動修復
スライスにメンテナンス イベント、プリエンプション、またはハードウェアの障害が発生すると、Cloud TPU が新しいスライスを作成します。新しいスライスを作成するのに十分なリソースがない場合、ハードウェアが利用可能になるまで作成は完了しません。新しいスライスを作成すると、マルチスライス環境内の他のすべてのスライスが再起動され、トレーニングを続行できます。適切に構成された起動スクリプトを使用すると、ユーザーの介入なしに、トレーニング スクリプトが自動的に再起動し、最新のチェックポイントから読み込み、再開します。
データセンター ネットワーキング(DCN)
マルチスライス構成で TPU スライスを接続する、高レイテンシ、低スループット(ICI との比較)のネットワーク。
ギャング スケジューリング
すべての TPU スライスが同時にプロビジョニングされた場合に、すべてのスライスが正常にプロビジョニングされるか、いずれのスライスもプロビジョニングされないことを保証します。
インターチップ相互接続(ICI)
TPU Pod 内で TPU を接続する高速かつ低レイテンシの内部リンク。
マルチスライス
DCN を介して通信できる 2 つ以上の TPU チップスライス
ノード
マルチスライスのコンテキストでは、ノードは単一の TPU スライスを指します。マルチスライスの各 TPU スライスにはノード ID が割り当てられます。
起動スクリプト
VM が起動または再起動されるたびに実行される標準の Compute Engine 起動スクリプトマルチスライスの場合、QR 作成リクエストで指定されます。Cloud TPU 起動スクリプトの詳細については、TPU リソースを管理するをご覧ください。
Tensor
ML モデルの多次元データを表すために使用されるデータ構造。
Cloud TPU の容量のタイプ

TPU は、さまざまなタイプの容量から作成できます(TPU の料金の仕組みの使用オプションを参照)。

  • 予約: 予約を使用するには、Google との予約契約が必要です。リソースを作成する際は --reserved フラグを使用します。

  • Spot: Spot VM を使用するプリエンプティブルの割り当てを対象にします。優先度の高いジョブのリクエストに対応できるように、リソースがプリエンプトされる場合があります。リソースを作成する際は --spot フラグを使用します。

  • オンデマンド: 予約を必要とせずプリエンプトされない、オンデマンド割り当てを対象にします。TPU リクエストは、Cloud TPU が提供するオンデマンド割り当てキューに追加されます。リソースの可用性は保証されません。デフォルトで選択されます。フラグは必要ありません。

始める

  1. Cloud TPU 環境を設定します

  2. In the Google Cloud console, activate Cloud Shell.

    Activate Cloud Shell

    At the bottom of the Google Cloud console, a Cloud Shell session starts and displays a command-line prompt. Cloud Shell is a shell environment with the Google Cloud CLI already installed and with values already set for your current project. It can take a few seconds for the session to initialize.

マルチスライスを使用するには、TPU リソースをキューに格納されたリソースとして管理する必要があります。

入門例

このチュートリアルでは、MaxText GitHub リポジトリのコードを使用します。MaxText は、Python と Jax で記述された、高パフォーマンスで任意にスケーラブルなオープンソースの十分にテストされた基本 LLM です。Cloud TPU での効率的なトレーニングを目的として設計されています。

shardings.py のコードは、さまざまな並列化オプションのテストを開始するうえで役立つように設計されています。たとえば、データ並列処理、完全にシャーディングされたデータ並列処理(FSDP)、テンソル並列処理などです。コードは、単一スライス環境からマルチスライス環境にスケーリングされます。

ICI 並列処理

ICI は、単一スライスの TPU を接続する高速相互接続を指します。ICI シャーディングは、スライス内のシャーディングに対応します。shardings.py には、次の 3 つの ICI 並列処理パラメータがあります。

  • ici_data_parallelism
  • ici_fsdp_parallelism
  • ici_tensor_parallelism

これらのパラメータに指定する値によって、各並列化メソッドのシャードの数が決まります。

これらの入力は、ici_data_parallelism * ici_fsdp_parallelism * ici_tensor_parallelism がスライス内のチップの数と等しくなるように制限する必要があります。

次の表に、v4-8 で使用可能な 4 チップの ICI 並列処理のユーザー入力の例を示します。

ici_data_parallelism ici_fsdp_parallelism ici_tensor_parallelism
4 方向 FSDP 1 4 1
4 方向テンソル並列処理 1 1 4
2 方向 FSDP + 2 方向テンソル並列処理 1 2 2

ほとんどの場合、ici_data_parallelism は 1 のままにしておきます。ICI ネットワークは十分高速で、ほぼ常にデータ並列処理よりも FSDP が優先されるためです。

この例は、JAX を使用して Cloud TPU VM で計算を実行するなど、単一の TPU スライスでのコード実行に精通していることを前提としています。この例は、単一のスライスで shardings.py を実行する方法を示しています。

  1. 環境を設定します。

    $ gcloud auth login
    $ export QR_ID=your-queued-resource-id
    $ export TPU_NAME=your-tpu-name
    $ export PROJECT=your-project-name
    $ export ZONE=us-central1-a
    $ export NETWORK_NAME=your-network-name
    $ export SUBNETWORK_NAME=your-subnetwork-name
    $ export RUNTIME_VERSION=v2-alpha-tpuv5-lite
    $ export ACCELERATOR_TYPE=v5litepod-16
    $ export EXAMPLE_TAG_1=your-tag-1
    $ export EXAMPLE_TAG_2=your-tag-2
    $ export SLICE_COUNT=4
    $ export STARTUP_SCRIPT='#!/bin/bash\n'

    変数の説明

    入力 説明
    QR_ID キューに格納されたリソースのユーザー割り当て ID。
    TPU_NAME ユーザーが割り当てた TPU の名前。
    PROJECT Google Cloud プロジェクト名
    ZONE リソースを作成するゾーンを指定します。
    NETWORK_NAME VPC ネットワークの名前。
    SUBNETWORK_NAME VPC ネットワーク内のサブネットの名前。
    RUNTIME_VERSION Cloud TPU ソフトウェアのバージョン
    ACCELERATOR_TYPE v4-16
    EXAMPLE_TAG_1、EXAMPLE_TAG_2 … ネットワーク ファイアウォールの有効なソースやターゲットを識別するために使用されるタグ。
    SLICE_COUNT スライスの数。上限は 256 スライスです。
    STARTUP_SCRIPT 起動スクリプトを指定すると、TPU スライスがプロビジョニングまたは再起動されたときにスクリプトが実行されます。
  2. gcloud の SSH 認証鍵を作成します。パスワードは空白のままにすることをおすすめします(次のコマンドの実行後に 2 回 Enter を押します)。google_compute_engine ファイルがすでに存在しているというメッセージが表示された場合は、既存のバージョンを置き換えます。

    $ ssh-keygen -f ~/.ssh/google_compute_engine
  3. TPU をプロビジョニングします。

    gcloud

    $ gcloud compute tpus queued-resources \
        create ${QR_ID} \
        --accelerator-type=${ACCELERATOR_TYPE} \
        --runtime-version=${RUNTIME_VERSION} \
        --node-id=${TPU_NAME} \
        --zone=${ZONE} \
        [--reserved |--spot]

    Google Cloud CLI では、タグなどの QR コードの作成オプションはサポートされていません。詳細については、QR を作成するをご覧ください。

    コンソール

    1. Google Cloud コンソールで、[TPU] ページに移動します。

      [TPU] に移動

    2. [TPU を作成] をクリックします。

    3. [名前] フィールドに、TPU の名前を入力します。

    4. [ゾーン] ボックスで、TPU を作成するゾーンを選択します。

    5. [TPU タイプ] ボックスで、アクセラレータ タイプを選択します。アクセラレータ タイプでは、作成する Cloud TPU のバージョンとサイズを指定します。TPU の各バージョンでサポートされているアクセラレータ タイプの詳細については、TPU のバージョンをご覧ください。

    6. [TPU ソフトウェア バージョン] ボックスで、ソフトウェア バージョンを選択します。Cloud TPU VM の作成時には、この TPU ソフトウェア バージョンによって、インストールされる TPU ランタイム バージョンが指定されます。詳細については、TPU ソフトウェア バージョンをご覧ください。

    7. [キューイングを有効にする] トグルをクリックします。

    8. [キューに格納されたリソースの名前] フィールドに、キューに格納されたリソース リクエストの名前を入力します。

    9. [作成] をクリックして、キューに格納されたリソース リクエストを作成します。

  4. キューに格納されたリソースが ACTIVE 状態になるまで待ちます。これは、ワーカーノードが READY 状態であることを意味します。キューに格納されたリソースのプロビジョニングが開始されると、そのサイズによっては、完了までに 1~5 分かかることがあります。キューに格納されたリソース リクエストのステータスは、gcloud CLI または Google Cloud コンソールを使用して確認できます。

    gcloud

    $ gcloud compute tpus queued-resources \
        list --filter=${QR_ID} --zone=${ZONE}

    コンソール

    1. Google Cloud コンソールで、[TPU] ページに移動します。

      [TPU] に移動

    2. [キューに格納されたリソース] タブをクリックします。

    3. キューに格納されたリソース リクエストの名前をクリックします。

  5. SSH を使用して TPU VM に接続します。

    $ gcloud compute tpus tpu-vm ssh ${TPU_NAME} --zone=${ZONE}
  6. TPU VM に MaxTextshardings.py を含む)のクローンを作成します。

    $ git clone https://github.com/AI-Hypercomputer/maxtext && cd maxtext
  7. Python 3.10 をインストールします。

    $ sudo apt-get update
    $ sudo apt install python3.10
    $ sudo apt install python3.10-venv
  8. 仮想環境を作成して有効にします。

    $ python3 -m venv your-venv-name
    $ source your-venv-name/bin/activate
  9. MaxText リポジトリ ディレクトリ内でセットアップ スクリプトを実行して、TPU スライスに JAX などの依存関係をインストールします。このスクリプトのセットアップには数分かかります。

    $ bash setup.sh
  10. 次のコマンドを実行して、TPU スライスで shardings.py を実行します。

    $ python3 -m pedagogical_examples.shardings \
      --ici_fsdp_parallelism 4 \
      --batch_size 131072 \
      --embedding_dimension 2048

    結果はログで確認できます。TPU は約 260 TFLOP/秒、または 90% 以上の FLOP 使用率を達成します。この例では、TPU の高帯域幅メモリ(HBM)に収まるほぼ最大のバッチを選択しています。

  11. ICI で他のシャーディング戦略を試すこともできます。たとえば、次の組み合わせが可能です。

    $ python3 -m pedagogical_examples.shardings \
      --ici_tensor_parallelism 4 \
      --batch_size 131072 \
      --embedding_dimension 2048
  12. 完了したら、キューに格納されたリソースと TPU スライスを削除します。これらのクリーンアップ手順は、スライスを設定した環境から実行する必要があります(まず exit を実行して SSH セッションを終了します)。削除が完了するまでに 2~5 分かかります。gcloud CLI を使用している場合は、オプションの --async フラグを使用して、このコマンドをバックグラウンドで実行できます。

    gcloud

    $ gcloud compute tpus queued-resources \
        delete ${QR_ID} --force (--async)

    コンソール

    1. Google Cloud コンソールで、[TPU] ページに移動します。

      [TPU] に移動

    2. [キューに格納されたリソース] タブをクリックします。

    3. キューに格納されたリソース リクエストの横にあるチェックボックスをオンにします。

    4. [削除] をクリックします。

DCN 並列処理を使用したマルチスライス シャーディング

shardings.py スクリプトは、データ並列処理の各タイプのシャード数に対応する、DCN 並列処理を指定する 3 つのパラメータを受け取ります。

  • dcn_data_parallelism
  • dcn_fsdp_parallelism
  • dcn_tensor_parallelism

これらのパラメータの値は、dcn_data_parallelism * dcn_fsdp_parallelism * dcn_tensor_parallelism がスライス数と等しくなるように制限する必要があります。

たとえばスライスが 2 つの場合は、--dcn_data_parallelism = 2 を使用します。

dcn_data_parallelism dcn_fsdp_parallelism dcn_tensor_parallelism スライス数
2 方向データ並列処理 2 1 1 2

DCN はこのようなシャーディングには適していないため、dcn_tensor_parallelism は常に 1 に設定する必要があります。v4 チップの一般的な LLM ワークロードでは、dcn_fsdp_parallelism1 に設定する必要があるため、dcn_data_parallelism をスライス数に設定する必要がありますが、これはアプリケーションによって異なります。

スライスの数を増やすと(スライスのサイズとスライスごとのバッチ数が一定に保たれていると想定)、データ並列処理の量が増えます。

マルチスライス環境で shardings.py を実行する

マルチスライス環境で shardings.py を実行するには、multihost_runner.py を使用するか、各 TPU VM で shardings.py を実行します。ここでは multihost_runner.py を使用します。次の手順は、MaxText リポジトリからのはじめに: 複数のスライスでの簡単なテストの手順と似ています。ただしここでは、train.py のより複雑な LLM の代わりに shardings.py を実行します。

multihost_runner.py ツールは、同じ TPU を繰り返し再利用する簡単なテストに最適化されています。multihost_runner.py スクリプトは長時間継続する SSH 接続に依存するため、実行時間が長いジョブにはおすすめしません。実行時間が長いジョブ(数時間または数日間など)を実行する場合は、multihost_job.py を使用することをおすすめします。

このチュートリアルでは、multihost_runner.py スクリプトを実行するマシンを示すために「ランナー」という用語を使用します。また、スライスを構成する TPU VM を示すために「ワーカー」という用語を使用します。multihost_runner.py は、ローカルマシン、またはスライスと同じプロジェクト内の任意の Compute Engine VM 上で実行できます。ワーカーでの multihost_runner.py の実行はサポートされていません。

multihost_runner.py は、SSH を使用して TPU ワーカーに自動的に接続します。

この例では、2 つの v5e-16 スライス(合計 4 台の VM と 16 個の TPU チップ)で shardings.py を実行します。より多くの TPU で実行するよう、例を変更できます。

環境を設定する

  1. ランナーマシンに MaxText のクローンを作成します。

    $ git clone https://github.com/AI-Hypercomputer/maxtext
  2. リポジトリ ディレクトリに移動します。

    $ cd maxtext
  3. gcloud の SSH 認証鍵を作成します。パスワードは空白のままにすることことをおすすめします(次のコマンドの実行後に 2 回 Enter を押します)。google_compute_engine ファイルがすでに存在しているというメッセージが表示された場合は、既存のバージョンを保持しないことを選択します。

      $ ssh-keygen -f ~/.ssh/google_compute_engine
      

  4. 環境変数を追加して、TPU スライス数を 2 に設定します。

      $ export SLICE_COUNT=2
      

  5. queued-resources create コマンドまたは Google Cloud コンソールを使用して、マルチスライス環境を作成します。

    gcloud

    次のコマンドは、v5e マルチスライス TPU を作成する方法を示しています。別の TPU バージョンを使用するには、別の accelerator-typeruntime-version を指定します。

    $ gcloud compute tpus queued-resources \
        create ${QR_ID} \
        --accelerator-type=${ACCELERATOR_TYPE} \
        --runtime-version=${RUNTIME_VERSION} \
        --node-count=${SLICE_COUNT} \
        --node-prefix=${TPU_NAME} \
        --zone=${ZONE} \
        [--reserved|--spot]

    コンソール

    1. Google Cloud コンソールで、[TPU] ページに移動します。

      [TPU] に移動

    2. [TPU を作成] をクリックします。

    3. [名前] フィールドに、TPU の名前を入力します。

    4. [ゾーン] ボックスで、TPU を作成するゾーンを選択します。

    5. [TPU タイプ] ボックスで、アクセラレータ タイプを選択します。アクセラレータ タイプでは、作成する Cloud TPU のバージョンとサイズを指定します。マルチスライスは、Cloud TPU v4 以降の TPU バージョンでのみサポートされています。TPU バージョンの詳細については、TPU のバージョンをご覧ください。

    6. [TPU ソフトウェア バージョン] ボックスで、ソフトウェア バージョンを選択します。Cloud TPU VM の作成時には、この TPU ソフトウェア バージョンによって、TPU VM にインストールされる TPU ランタイムのバージョンが指定されます。詳細については、TPU ソフトウェア バージョンをご覧ください。

    7. [キューイングを有効にする] トグルをクリックします。

    8. [キューに格納されたリソースの名前] フィールドに、キューに格納されたリソース リクエストの名前を入力します。

    9. [マルチスライス TPU にする] チェックボックスをオンにします。

    10. [スライス数] フィールドに、作成するスライスの数を入力します。

    11. [作成] をクリックして、キューに格納されたリソース リクエストを作成します。

  6. キューに格納されたリソースのプロビジョニングが開始されると、そのサイズによっては、完了までに最大 5 分かかることがあります。キューに格納されたリソースが ACTIVE 状態になるまで待ちます。キューに格納されたリソース リクエストのステータスは、gcloud CLI または Google Cloud コンソールを使用して確認できます。

    gcloud

    $ gcloud compute tpus queued-resources list \
        --filter=${QR_ID} --zone=${ZONE} --project=${PROJECT}

    次のような出力が生成されます。

    NAME        ZONE           NODE_COUNT  ACCELERATOR_TYPE  STATE
    ...
    que-res-id  us-central2-b  4           v5litepod-16             ACTIVE
    ...

    コンソール

    1. Google Cloud コンソールで、[TPU] ページに移動します。

      [TPU] に移動

    2. [キューに格納されたリソース] タブをクリックします。

    3. キューに格納されたリソース リクエストの名前をクリックします。

    QR のステータスが 15 分以上 WAITING_FOR_RESOURCES または PROVISIONING 状態になっている場合は、 Google Cloud アカウント担当者にお問い合わせください。

  7. 依存関係をインストールします。

    $ python3 multihost_runner.py \
        --TPU_PREFIX=${TPU_NAME} \
        --ZONE=${ZONE} \
        --COMMAND="bash setup.sh"
  8. multihost_runner.py を使用して、各ワーカーで shardings.py を実行します。

    $ python3 multihost_runner.py \
        --TPU_PREFIX=${TPU_NAME} \
        --ZONE=${ZONE} \
        --COMMAND="python3 -m pedagogical_examples.shardings \
        --dcn_data_parallelism ${SLICE_COUNT} \
        --ici_fsdp_parallelism 16 \
        --batch_size 131072 \
        --embedding_dimension 2048"

    ログファイルには、約 230 TFLOP/秒のパフォーマンスが記録されます。

    並列処理の構成の詳細については、DCN 並列処理を使用したマルチスライス シャーディングshardings.py をご覧ください。

  9. 完了したら、TPU とキューに格納されたリソースをクリーンアップします。削除が完了するまでに 2~5 分かかります。gcloud CLI を使用している場合は、オプションの --async フラグを使用して、このコマンドをバックグラウンドで実行できます。

ワークロードをマルチスライスにスケーリングする

マルチスライス環境でモデルを実行する前に、次のようにコードを変更します。

マルチスライスに移行する際に必要なコード変更は、これだけです。高いパフォーマンスを実現するには、DCN をデータ並列軸、完全にシャーディングされたデータ並列軸、またはパイプライン並列軸にマッピングする必要があります。パフォーマンスに関する考慮事項とシャーディング戦略の詳細については、パフォーマンスを最大限に高めるためのマルチスライスのシャーディングをご覧ください。

コードがすべてのデバイスにアクセスできることを確認するには、len(jax.devices()) がマルチスライス環境のチップ数と等しいことをアサートします。たとえば、v4-16 の 4 つのスライスを使用している場合、スライスあたり 8 個のチップ × 4 つのスライスがあるため、len(jax.devices()) は 32 を返します。

マルチスライス環境のスライスサイズを選択する

速度を線形的に向上させるには、既存のスライスと同じサイズの新しいスライスを追加します。たとえば、v4-512 スライスを使用する場合、マルチスライスでは、2 番目の v4-512 スライスを追加してグローバル バッチサイズを 2 倍にすることで、約 2 倍のパフォーマンスを実現できます。詳細については、パフォーマンスを最大限に高めるためのマルチスライスのシャーディングをご覧ください。

複数のスライスでジョブを実行する

マルチスライス環境でカスタム ワークロードを実行するには、次の 3 つの方法があります。

  1. テスト用ランナー スクリプト multihost_runner.py を使用する
  2. 本番環境ランナー スクリプト multihost_job.py を使用する
  3. 手動アプローチを使用する

テスト用ランナー スクリプト

multihost_runner.py スクリプトは、既存のマルチスライス環境にコードを配布し、各ホストでコマンドを実行して、ログをコピーし、各コマンドのエラー ステータスを追跡します。multihost_runner.py スクリプトについては、MaxText の README をご覧ください。

multihost_runner.py では永続的な SSH 接続が維持されるため、サイズが小さく比較的実行時間が短いテストにのみ適しています。multihost_runner.py チュートリアルの手順は、ワークロードとハードウェアの構成に合わせて調整できます。

本番環境ランナー スクリプト

ハードウェアの障害やその他のプリエンプションに対する復元力が必要な本番環境ジョブの場合は、Create Queued Resource API と直接統合することをおすすめします。実用的な例として multihost_job.py を使用します。このスクリプトは、適切な起動スクリプトを使用して Created Queued Resource API 呼び出しをトリガーし、トレーニングを実行してプリエンプション時に再開します。multihost_job.py スクリプトについては、MaxText の README をご覧ください。

multihost_job.py は実行ごとにリソースをプロビジョニングする必要があるため、multihost_runner.py のような速いイテレーション サイクルは提供されません。

手動アプローチ

マルチスライス構成でカスタム ワークロードを実行するには、multihost_runner.py または multihost_job.py を使用するか、適宜変更することをおすすめします。ただし、QR コマンドを直接使用して環境をプロビジョニングして管理する場合は、マルチスライス環境を管理するをご覧ください。

マルチスライス環境を管理する

MaxText リポジトリで提供されるツールを使用せずに、QR を手動でプロビジョニングして管理するには、以下のセクションをご覧ください。

キューに格納されたリソースを作成する

gcloud

  1. 次のコマンドを使用して、キューに格納されたリソース リクエストを作成します。

    $ gcloud compute tpus queued-resources \
        create ${QR_ID} \
        --project=${PROJECT} \
        --zone=${ZONE} \
        --node-count=${SLICE_COUNT} \
        --accelerator-type=${ACCELERATOR_TYPE} \
        --runtime-version=${RUNTIME_VERSION} \
        --network=${NETWORK_NAME} \
        --subnetwork=${SUBNETWORK_NAME} \
        --tags=${EXAMPLE_TAG_1},${EXAMPLE_TAG_2} \
        --metadata=startup-script="${STARTUP_SCRIPT}" \
        [--reserved|--spot]

--reserved--spot、またはデフォルトのオンデマンド割り当てを選択する前に、対象の割り当てがあることを確認してください。割り当てタイプの詳細については、割り当てポリシーをご覧ください。

curl

  1. queued-resource-req.json という名前のファイルを作成して、次の JSON をコピーします。

    {
    "guaranteed": { "reserved": true },
    "tpu": {
        "node_spec": [
        {
        "parent": "projects/your-project-number/locations/your-zone",
            "node": {
            "accelerator_type": "accelerator-type",
            "runtime_version": "tpu-vm-runtime-version",
            "network_config": {
                "network": "your-network-name",
                "subnetwork": "your-subnetwork-name",
                "enable_external_ips": true
            },
            "tags" : ["example-tag-1"]
            "metadata": {
                "startup-script": "your-startup-script"
            }
        },
        "multi_node_params": {
            "node_count": slice-count,
            "node_id_prefix": "your-queued-resource-id"
        }
        }
        ]
    }
    }

    次の値を置き換えます。

    • your-project-number - Google Cloud プロジェクトの番号。
    • your-zone - キューに格納されたリソースを作成するゾーン。
    • accelerator-type - 単一スライスのバージョンとサイズ。マルチスライスは、Cloud TPU v4 以降の TPU バージョンでのみサポートされています。
    • tpu-vm-runtime-version - 使用する TPU VM ランタイム バージョン。
    • your-network-name - キューに格納されたリソースが接続されるネットワーク(省略可)。
    • your-subnetwork-name - キューに格納されたリソースが接続されるサブネットワーク(省略可)。
    • example-tag-1 - 任意のタグ文字列(省略可)。
    • your-startup-script - キューに格納されたリソースが割り当てられるときに実行される起動スクリプト。
    • slice-count - マルチスライス環境内の TPU スライスの数。
    • your-queued-resource-id - キューに格納されたリソースのユーザー指定 ID。

    利用可能なすべてのオプションに関する詳細については、REST キューに格納されたリソース API のドキュメントをご覧ください。

    Spot 容量を使用するには、次のように置き換えます。

    "guaranteed": { "reserved": true }"spot": {} に置き換えます。

    デフォルトのオンデマンド容量を使用するには、この行を削除します。

  2. JSON ペイロードを使用して、キューに格納されたリソース作成リクエストを送信します。

    $ curl -X POST -H "Authorization: Bearer $(gcloud auth print-access-token)" \
    -H "Content-Type: application/json" \
    -d @queuedresourcereq.json \
    https://tpu.googleapis.com/v2alpha1/projects/your-project-id/locations/your-zone/queuedResources\?queued_resource_id\=your-queued-resource-id

    次の値を置き換えます。

    • your-project-id - Google Cloud プロジェクト ID。
    • your-zone - キューに格納されたリソースを作成するゾーン。
    • your-queued-resource-id - キューに格納されたリソースのユーザー指定 ID。

レスポンスは次のようになります。

{
"name": "projects/<your-project-id>/locations/<your-zone>/operations/operation-<your-qr-guid>",
"metadata": {
    "@type": "type.googleapis.com/google.cloud.common.OperationMetadata",
    "createTime": "2023-11-01T00:17:05.742546311Z",
    "target": "projects/<your-project-id>/locations/<your-zone>/queuedResources/<your-qa-id>",
    "verb": "create",
    "cancelRequested": false,
    "apiVersion": "v2alpha1"
},
"done": false
}

name 属性の文字列値の末尾にある GUID 値を使用して、キューに格納されたリソース リクエストに関する情報を取得します。

コンソール

  1. Google Cloud コンソールで、[TPU] ページに移動します。

    [TPU] に移動

  2. [TPU を作成] をクリックします。

  3. [名前] フィールドに、TPU の名前を入力します。

  4. [ゾーン] ボックスで、TPU を作成するゾーンを選択します。

  5. [TPU タイプ] ボックスで、アクセラレータ タイプを選択します。アクセラレータ タイプでは、作成する Cloud TPU のバージョンとサイズを指定します。マルチスライスは、Cloud TPU v4 以降の TPU バージョンでのみサポートされています。TPU の各バージョンでサポートされているアクセラレータ タイプの詳細については、TPU のバージョンをご覧ください。

  6. [TPU ソフトウェア バージョン] ボックスで、ソフトウェア バージョンを選択します。Cloud TPU VM の作成時には、この TPU ソフトウェア バージョンによって、インストールされる TPU ランタイム バージョンが指定されます。詳細については、TPU ソフトウェア バージョンをご覧ください。

  7. [キューイングを有効にする] トグルをクリックします。

  8. [キューに格納されたリソースの名前] フィールドに、キューに格納されたリソース リクエストの名前を入力します。

  9. [マルチスライス TPU にする] チェックボックスをオンにします。

  10. [スライス数] フィールドに、作成するスライスの数を入力します。

  11. [作成] をクリックして、キューに格納されたリソース リクエストを作成します。

キューに格納されたリソースのステータスを取得する

gcloud

$ gcloud compute tpus queued-resources describe ${QR_ID} --zone=${ZONE}

キューに格納されたリソースが ACTIVE 状態の場合、出力は次のようになります。

...
state:
    state: ACTIVE
...

curl

$ curl -X GET -H "Authorization: Bearer $(gcloud auth print-access-token)" -H "Content-Type: application/json" https://tpu.googleapis.com/v2/projects/your-project-id/locations/your-zone/queuedResources/${YOUR_QR_ID}

キューに格納されたリソースが ACTIVE 状態の場合、出力は次のようになります。

{
"name": your-queued-res,
"tpu": {
    "nodeSpec": [
    {
        ... // node 1
    },
    {
        ... // node 2
    },
    ...
    ]
},
...
"state": "ACTIVE"
}

コンソール

  1. Google Cloud コンソールで、[TPU] ページに移動します。

    [TPU] に移動

  2. [キューに格納されたリソース] タブをクリックします。

  3. キューに格納されたリソース リクエストの名前をクリックします。

TPU がプロビジョニングされたら、[TPU] ページに移動して TPU を見つけ、対応するキューに格納されたリソース リクエストの名前をクリックして、キューに格納されたリソース リクエストの詳細を確認することもできます。

まれに、キューに格納されたリソースの状態が FAILED であるのに対し、一部のスライスが ACTIVE である場合があります。この場合は、作成したリソースを削除してから、数分後にもう一度試すか、Google Cloud サポートにお問い合わせください。

SSH を使用して依存関係をインストールする

TPU スライスで JAX コードを実行するでは、単一スライスで SSH を使用して TPU VM に接続する方法について説明しています。SSH を介してマルチスライス環境内のすべての TPU VM に接続し、依存関係をインストールするには、次の gcloud コマンドを使用します。

  $ gcloud compute tpus queued-resources ssh ${QR_ID} \
        --zone=${ZONE} \
        --node=all \
        --worker=all \
        --command="command-to-run" \
        --batch-size=4

この gcloud コマンドは、SSH を使用して、指定されたコマンドを QR 内のすべてのワーカーとノードに送信します。コマンドは 4 つずつのグループでバッチ処理され、同時に送信されます。現在のバッチの実行が完了すると、コマンドの次のバッチが送信されます。いずれかのコマンドでエラーが発生すると、処理が停止し、それ以降のバッチは送信されません。詳細については、キューに格納されたリソースの API リファレンスをご覧ください。使用しているスライスの数がローカル コンピュータのスレッド上限(バッチ上限とも呼ばれる)を超えると、デッドロックが発生します。たとえば、ローカルマシンのバッチ上限が 64 であるとします。64 を超えるスライス(100 など)でトレーニング スクリプトを実行しようとすると、SSH コマンドでスライスがバッチに分割されます。最初の 64 スライスのバッチでトレーニング スクリプトを実行し、スクリプトが完了するまで待ってから、残りの 36 スライスのバッチでスクリプトを実行します。ただし、残りの 36 スライスがスクリプトの実行を開始するまで、最初の 64 スライスのバッチは完了できず、デッドロックが発生します。

このシナリオを回避するには、--command フラグで指定したスクリプト コマンドにアンパサンド(&)を追加して、各 VM でトレーニング スクリプトをバックグラウンドで実行します。この処理を行う場合、スライスの最初のバッチでトレーニング スクリプトを開始すると、制御が直ちに SSH コマンドに戻ります。これにより、SSH コマンドは、残りの 36 スライスのバッチでトレーニング スクリプトの実行を開始できます。バックグラウンドでコマンドを実行する場合は、stdout ストリームと stderr ストリームを適切にパイプ処理する必要があります。同じ QR 内で並列処理を増やすには、--node パラメータを使用して特定のスライスを選択します。

ネットワーク設定

次の手順に沿って、TPU スライスが相互に通信できることを確認します。スライスに JAX をインストールします。詳細については、Cloud TPU Pod スライスで JAX コードを実行するをご覧ください。len(jax.devices()) がマルチスライス環境のチップ数と等しいことをアサートします。そのためには、各スライスで次のコードを実行します。

  $ python3 -c 'import jax; print(jax.devices())'

このコードを v4-16 の 4 つのスライスで実行する場合、スライスごとに 8 つのチップと 4 つのスライスがあり、合計 32 チップ(デバイス)が jax.devices() によって返されます。

キューに格納されたリソースを一覧表示する

gcloud

キューに格納されたリソースの状態を確認するには、queued-resources list コマンドを使用します。

$ gcloud compute tpus queued-resources list --zone=${ZONE}

出力は次のようになります。

NAME        ZONE           NODE_COUNT  ACCELERATOR_TYPE  STATE
...
que-res-id  us-central1-a  4           v5litepod-16             ACTIVE
...

コンソール

  1. Google Cloud コンソールで、[TPU] ページに移動します。

    [TPU] に移動

  2. [キューに格納されたリソース] タブをクリックします。

プロビジョニングされた環境でジョブを開始する

ワークロードを手動で実行するには、SSH を介して各スライスのすべてのホストに接続し、すべてのホストで次のコマンドを実行します。

$ gcloud compute tpus tpu-vm ssh ${TPU_NAME} \
    --zone=${ZONE} \
    --worker=all \
    --command="command-to-run"

QR をリセットする

ResetQueuedResource API を使用すると、ACTIVE QR 内のすべての VM をリセットできます。VM をリセットすると、マシンのメモリが強制的に消去され、VM は初期状態にリセットされます。ローカルに保存されているデータはそのまま保持され、リセット後に起動スクリプトが呼び出されます。ResetQueuedResource API は、すべての TPU を再起動する場合に便利です。たとえば、トレーニングが停止して、すべての VM をリセットする方がデバッグよりも簡単な場合などに利用できます。

すべての VM のリセットは並列で実行され、ResetQueuedResource オペレーションの完了には 1~2 分かかります。API を呼び出すには、次のコマンドを使用します。

$ gcloud compute tpus queued-resources reset ${QR_ID} --zone=${ZONE}

キューに格納されたリソースを削除する

トレーニング セッションの終了時にリソースを解放するには、キューに格納されたリソースを削除します。削除が完了するまでに 2~5 分かかります。gcloud CLI を使用している場合は、オプションの --async フラグを使用して、このコマンドをバックグラウンドで実行できます。

gcloud

$ gcloud compute tpus queued-resources \
    delete ${QR_ID} --zone=${ZONE} --force [--async]

コンソール

  1. Google Cloud コンソールで、[TPU] ページに移動します。

    [TPU] に移動

  2. [キューに格納されたリソース] タブをクリックします。

  3. キューに格納されたリソース リクエストの横にあるチェックボックスをオンにします。

  4. [削除] をクリックします。

障害からの自動回復

障害が発生した場合、マルチスライスでは、影響を受けたスライスを介入なしで修復し、その後すべてのスライスをリセットできます。影響を受けたスライスが新しいスライスに置き換えられ、残りの正常なスライスはリセットされます。置き換えるスライスを割り当てるための容量がない場合、トレーニングは停止します。

中断後にトレーニングを自動的に再開するには、最後に保存されたチェックポイントをチェックし、読み込む起動スクリプトを指定する必要があります。起動スクリプトは、スライスが再割り当てされるか、VM がリセットされるたびに自動的に実行されます。create QR request API に送信する JSON ペイロードで起動スクリプトを指定します。

次の起動スクリプト(QR を作成するで使用)を使用すると、MaxText トレーニング中に、障害から自動的に回復し、Cloud Storage バケットに保存されているチェックポイントからトレーニングを再開できます。

{
 "tpu": {
   "node_spec": [
     {
      ...
         "metadata": {
               "startup-script": "#! /bin/bash \n pwd \n runuser -l user1 -c 'cd /home/user1/MaxText && python3 -m MaxText.train MaxText/configs/base.yml run_name=run_test_failure_recovery dcn_data_parallelism=4 ici_fsdp_parallelism=8 steps=10000 save_period=10 base_output_directory='gs://user1-us-central2'' EOF"
         }
     ...
     }
   ]
 }
}

これを試す前に、MaxText リポジトリのクローンを作成してください。

プロファイリングとデバッグを行う

プロファイリングは、単一スライス環境とマルチスライス環境で同じです。詳細については、JAX プログラムのプロファイリングをご覧ください。

トレーニングを最適化する

以降のセクションでは、マルチスライス トレーニングを最適化する方法について説明します。

パフォーマンスを最大限に高めるためのマルチスライスのシャーディング

マルチスライス環境でパフォーマンスを最大限に高めるには、複数のスライスにシャーディングする方法を考慮する必要があります。一般的に、選択肢は 3 つあります(データ並列処理、完全にシャーディングされたデータ並列処理、パイプライン並列処理)。モデルのディメンション間でのアクティベーションのシャーディング(テンソル並列処理とも呼ばれます)は、必要なスライス間帯域幅が大きすぎるため推奨されません。上記のいずれの戦略でも、過去にうまくいったスライス内では引き続き使用できます。

最初は純粋なデータ並列処理から始めることをおすすめします。完全にシャーディングされたデータ並列処理を使用すると、メモリ使用量を削減できます。ただし、スライス間の通信で DCN ネットワークを使用するため、ワークロードが遅くなるという欠点があります。パイプライン並列処理は、バッチサイズに基づいて必要な場合にのみ使用します(以下の分析を参照)。

データ並列処理を使用する場面

純粋なデータ並列処理は、ワークロードが正常に実行されているが、複数のスライスにスケーリングしてパフォーマンスを向上させる必要がある場合に適しています。

複数のスライスで強力なスケーリングを行うには、DCN を介して all-reduce を実行するために必要な時間が、バックワード パスを実行するために必要な時間よりも短くなければなりません。DCN はスライス間の通信に使用され、ワークロードのスループットの制限要因となります。

各 v4 TPU チップは、ピーク時に 275 x 1012 FLOPS/秒で動作します。

TPU ホストごとに 4 つのチップがあり、各ホストの最大ネットワーク帯域幅は 50 Gbps です。

つまり、算術強度は 4 × 275 × 10 12 FLOPS ÷ 50 Gbps = 22,000 FLOPS / ビットです。

モデルでは、ステップごとに各パラメータに 32~64 ビットの DCN 帯域幅が使用されます。2 つのスライスを使用する場合、モデルは 32 ビットの DCN 帯域幅を使用します。3 つ以上のスライスを使用する場合、コンパイラによって完全なシャッフル all-reduce 演算が実行され、ステップごとに各パラメータに対して最大 64 ビットの DCN 帯域幅が使用されます。各パラメータに必要な FLOPS の数はモデルによって異なります。具体的には、Transformer ベースの言語モデルの場合、フォワードパスとバックワード パスに必要な FLOPS の数は約 6 x B x P です。ここで、各アルファベットは以下を表します。

  • B: トークンのバッチサイズ
  • P: パラメータの数

パラメータあたりの FLOPS の数は 6 * B で、バックワード パス中のパラメータあたりの FLOPS の数は 4 * B です。

複数のスライス間で強力なスケーリングを行うには、演算強度が TPU ハードウェアの算術強度を上回ることを確認します。演算強度を計算するには、バックワード パス中のパラメータあたりの FLOPS 数を、1 ステップあたりのパラメータごとのネットワーク帯域幅(ビット数)で割ります。 Operational Intensity = FLOPSbackwards_pass / DCN bandwidth

したがって、Transformer ベースの言語モデルで、2 つのスライスを使用する場合は次のようになります。 Operational intensity = 4 * B / 32

3 つ以上のスライスを使用する場合は次のようになります。Operational intensity = 4 * B/64

これにより、Transformer ベースの言語モデルの最小バッチサイズは 176,000~352,000 になります。DCN ネットワークは一時的にパケットをドロップできるため、十分な許容誤差を維持し、Pod あたりのバッチサイズが 350,000(2 つの Pod)~700,000(多数の Pod)の場合にのみ、データ並列処理をデプロイすることをおすすめします。

他のモデルのアーキテクチャでは、(プロファイラを使用してタイミングを設定するか、FLOPS を計算して)スライスあたりのバックワード パスの実行時間を推定する必要があります。その時間を、DCN を介した all-reduce の予想実行時間と比較することで、データ並列処理が有効かどうかを正確に推定できます。

完全にシャーディングされたデータ並列処理(FSDP)を使用する場面

完全にシャーディングされたデータ並列処理(FSDP)は、データ並列処理(ノード間でのデータのシャーディング)とノード間の重みのシャーディングを組み合わせたものです。フォワードパスとバックワード パスのオペレーションごとに、重みが all-gather され、各スライスに必要な重みが付けられます。all-reduce を使用して勾配を同期させる代わりに、勾配が生成されると reduce-scatter が行われます。このようにして、各スライスの該当する重みの勾配のみが得られます。

FSDP では、データ並列処理と同様に、グローバル バッチサイズをスライス数に応じて線形スケーリングする必要があります。FSDP では、スライスの数を増やすとメモリ負荷が軽減されます。これは、スライスあたりの重みとオプティマイザーの状態の数が減少する代わりに、ネットワーク トラフィックが増加し、コレクティブの遅延によるブロッキングの可能性が増大するためです。

実際には、スライス間の FSDP が最適になるのは、スライスあたりのバッチ数を増やす場合、バックワード パス中の再実体化を最小限に抑えるためにより多くのアクティベーションを保存する場合、またはニューラル ネットワークのパラメータ数を増やす場合です。

FSDP の all-gather 演算と all-reduce 演算は DP の演算と似ているため、前のセクションの説明と同じ方法で、FSDP ワークロードが DCN のパフォーマンスによって制限されているかどうかを判断できます。

パイプライン並列処理を使用する場面

パイプライン並列処理は、推奨最大バッチサイズよりも大きいグローバル バッチサイズを必要とする他の並列処理戦略で高いパフォーマンスを達成する場合に重要になります。パイプライン並列処理では、パイプラインを構成するスライス間でバッチを「共有」できます。ただし、パイプライン並列処理には次の 2 つの大きなデメリットがあります。

  1. チップがデータを待機しているため、アイドル状態になる「パイプライン バブル」が発生します。
  2. 効果的なバッチサイズ、算術強度、ひいてはモデルの FLOP 使用率を減少させるマイクロバッチ処理を必要とします。

パイプライン並列処理は、他の並列処理戦略で必要とされるグローバル バッチサイズが大きすぎる場合にのみ使用してください。パイプライン並列処理を試す前に、高パフォーマンスの FSDP を達成するために必要なバッチサイズでサンプルあたりの収束が遅くなるかどうかを実験的に確認することをおすすめします。FSDP の方がモデルの FLOPS 使用率は高くなる傾向にありますが、バッチサイズの増加に伴いサンプルあたりの収束が遅くなる場合は、パイプライン並列処理の方が適している可能性があります。ほとんどワークロードは、パイプライン並列処理のメリットを享受できないほど大きなバッチサイズを許容できますが、そうではない場合もあります。

パイプライン並列処理が必要な場合は、データ並列処理または FSDP と組み合わせることをおすすめします。これにより、DCN レイテンシがスループットに大きく影響しない程度まで、パイプラインあたりのバッチサイズを増やしながら、パイプラインの深さを最小限に抑えることができます。具体的には、N 個のスライスがある場合は、深さ 2 のパイプラインと N/2 個のデータ並列処理のレプリカを検討し、次に深さ 4 のパイプラインと N/4 個のデータ並列処理のレプリカを検討するというように、DCN のコレクティブがバックワード パスの算術の背後に隠れるほど十分にパイプラインあたりのバッチが大きくなるまで同様の手順を繰り返します。これにより、パイプライン並列処理によって生じる遅延を最小限に抑えながら、グローバル バッチサイズの上限を超えてスケーリングできます。

マルチスライスに関するベスト プラクティス

以降のセクションでは、マルチスライス トレーニングのベスト プラクティスについて説明します。

データ読み込み

トレーニング中は、データセットからバッチを繰り返し読み込み、モデルにフィードします。TPU 使用率の低下を回避するには、バッチをホスト間でシャーディングする効率的な非同期データローダーが必要です。MaxText の現在のデータローダーでは、各ホストがサンプルの同等のサブセットを読み込みます。このソリューションはテキストには適していますが、モデル内での再シャーディングが必要です。また、MaxText では、データ イテレータがプリエンプションの前後に同じデータを読み込むことができる決定的スナップショットの作成機能はまだ提供されていません。

チェックポインティング

Orbax チェックポインティング ライブラリは、JAX PyTree をローカル ストレージまたは Google Cloud ストレージにチェックポインティングするためのプリミティブを提供します。checkpointing.py では、MaxText への同期チェックポインティングを備えたリファレンスのインテグレーションが提供されます。

サポートされている構成

以降のセクションでは、マルチスライスでサポートされているスライス形状、オーケストレーション、フレームワーク、並列処理について説明します。

形状

すべてのスライスが同じ形状(たとえば同じ AcceleratorType)である必要があります。異種スライス形状はサポートされていません。

オーケストレーション

オーケストレーションは GKE でサポートされています。詳細については、GKE の TPU をご覧ください。

フレームワーク

マルチスライスは、JAX と PyTorch のワークロードのみをサポートします。

並列処理

データ並列処理でマルチスライスをテストすることをおすすめします。マルチスライスを使用したパイプライン並列処理の実装の詳細については、Google Cloud アカウント担当者にお問い合わせください。

サポートとフィードバック

フィードバックをぜひお寄せください。フィードバックを共有したり、サポートをリクエストしたりするには、Cloud TPU サポートまたはフィードバック フォームを使用してご連絡ください。