Pathways を使用した復元力のあるトレーニング

Pathways は、次のような方法で復元力のメリットを提供します。

  • 一時停止と再開: ユーザーがカスタム プリエンプション処理コードを記述しなくても、プリエンプション通知などの計画的な中断に対する耐性。
  • Elastic Training: クライアントをクラッシュさせることなく、計画外のハードウェア障害に対する耐性がありますが、ユーザーはモデル固有の復元コードを記述する必要があります。

    始める前に

    インストールに必要なもの:

    一時停止と再開

    通常、GKE は Pod がプリエンプトされる前に、アクセラレータ Pod にプリエンプション通知を送信します。パスウェイのプリエンプション許容範囲は、すべてのクラウド デプロイでデフォルトで有効になっており、パスウェイ アクセラレータ ジョブはこれらの通知をリッスンします。

    プリエンプション通知が届くと、Pathways はまず現在のワークロードが復元可能かどうか、つまり Pathways がワークロードを透過的に保存して復元できるかどうかを判断します。その場合、GKE がアクセラレータ ジョブを削除する前に、現在の状態を Cloud Storage などの永続ストレージに書き出して、ML ワークロードを透過的に一時停止しようとします。GKE が後でジョブを再スケジュールすると、Pathways は永続化された状態を読み取って ML ワークロードを再開します。

    ワークロードを復元できない場合、Pathways はアクセラレータ ジョブをシャットダウンし、エラスティック トレーニングが構成されている場合は、障害をジョブに転送します。Elastic トレーニングが構成されていない場合、GKE は JobSet 再起動ポリシーに基づいてワークロード全体を再起動します。

    JAX を使用して定義された一般的な ML ワークロードは、高帯域幅メモリ(HBM)スナップショットを使用して復元可能なステートレス Pathways XLA コンポーネントに依存しています。JAX のコロケーションされた Python API を使用して定義された ML ワークロードなど、一部の ML ワークロードはステートフル Pathways コンポーネントに依存しています。これらのコンポーネントは復元できません。

    弾力性のあるトレーニング

    エラスティック トレーニングを使用すると、ハードウェア障害が発生してもトレーニング ジョブを続行できます。これは、Pathways システムの機能とユーザー定義のモデル復元ロジックの組み合わせによって実現されます。

    • 障害の検出: ハードウェア障害が発生した場合(TPU ワーカーのクラッシュなど)、Pathways システムはこれを検出し、そのハードウェアに保存されているデータに次回アクセスしたときに、例外を介してユーザーのトレーニング ジョブに通知します。この通知によってワークロードがクラッシュすることはありません。コードで通知を処理し、リソースを再構成して処理を続行するか、正常に終了できます。
    • ユーザー定義の伸縮性ハンドラ: ユーザーのモデルコードでこの例外を処理できる必要があります。これが「モデル固有の復元」です。
      • スナップショット: 最も一般的な方法は、モデルの状態のスナップショットを定期的に保存することです。障害が発生した場合は、最新のスナップショットから読み込んでトレーニングを再開できます。
      • 再構成: 使用可能なスライスの数に合わせてトレーニング ジョブを再構成する必要がある場合があります。たとえば、1 つのスライスが動作を停止した場合、代替が利用可能になるまでアクティブなスライスの数を 1 つ減らすことができます。詳細については、Elastic ハンドラをご覧ください。
      • データ/計算グラフの更新: コードは、必要に応じて計算グラフを再作成することで、計算に使用できるデバイス数の変更を処理する必要があります。これには、データの再パーティショニングやモデルの再コンパイルが必要になる場合があります。
    • 復元における Pathways の役割: Pathways は、ユーザー定義の再構成をサポートするプリミティブを提供します。
      • スライスの置換: 失敗したスライスが置換された場合、新しいスライスが利用可能になった時点でクライアントに通知できます。コードは、この新しいスライスを使用するように再構成できます。
      • 透過的な復元: Pathways は、クラスタの正常な部分への接続の再確立など、復元の低レベルの詳細を処理します。
    • pathwaysutils のユーティリティ: pathways-utils で定義された Pathways ユーティリティのセット。

    エラスティック ハンドラを実装する

    記述する必要があるコードのほとんどは、ユーザー定義の弾性ハンドラにあります。このハンドラは、TPU スライスが使用不可になるなどのエラスティック イベントに応答して、メッシュを再作成し、トレーニング ループを再初期化します。

    ワークロードはそれぞれ異なります。エラスティック ハンドラの複雑さは、ワークロードの複雑さに応じてスケーリングされる可能性があります。ハンドラの入力と出力は、トレーニング ループを再初期化するために必要な最小限の引数と戻り値にする必要があります。

    def elastic_handler(elastic_utils, *args, **kwargs):
      mesh = initialize_mesh(**kwargs["mesh_kwargs"])
      initial_state, initial_step, jitted_train_step, other_variables =
          initialize_training_loop(mesh, **kwargs["initialize_training_loop_kwargs"])
    
      step, snapshot = elastic_utils.get_next_snapshot()
      state = initial_state.replace(**snapshot)
    
      return state, step, mesh, jitted_train_step, other_variables
    

    トレーニング ループを更新する

    トレーニング ループに次の変更を加える必要があります。

    1. エラスティック マネージャーを作成する
    2. jax.errors.JaxRuntimeError を処理する try-except ブロック内にトレーニング ループをラップする
    3. jax.errors.JaxRuntimeError ハンドラ内で、maybe_reshard_down を呼び出します。エラーがエラスティック イベントに関連している場合、エラスティック マネージャーはシャードダウンを再実行します。それ以外の場合は、エラーを再発生させます。
    4. トレーニング ループの最後に maybe_snapshotmaybe_reshard_up を呼び出す
    import pathwaysutils
    from pathwaysutils.elastic import manager
    
    pathwaysutils.initialize()
    
    def initialize_mesh(**kwargs):
      ...
    
    
    def initialize_training_loop(**kwargs):
      ...
    
    
    def train_loop(
        final_step,
        elastic_manager,
        mesh_kwargs,
        initialize_training_loop_kwargs,
    ):
      mesh = initialize_mesh(**mesh_kwargs)
      initial_state, initial_step, jitted_train_step, other_variables =
          initialize_training_loop(mesh, **initialize_training_loop_kwargs)
    
      step = initial_step
      while step < final_step:
        try:
          state = jitted_train_step(state)
    
          elastic_manager.maybe_snapshot(step=step, snapshot=state)
          handler_returns = elastic_manager.maybe_reshard_up(
              step=step,
              snapshot=state,
              elastic_handler=elastic_handler,
              handler_args=(),
              handler_kwargs=dict(
                  mesh_kwargs=mesh_kwargs,
                  initialize_training_loop_kwargs=initialize_training_loop_kwargs,
              ),
          )
          if handler_returns:
            state, step, mesh, jitted_train_step, other_variables = handler_returns
          step += 1
        except jax.errors.JaxRuntimeError as error:
          handler_returns = elastic_manager.maybe_reshard_down(
              error=error,
              elastic_handler=elastic_handler,
              handler_args=(),
              handler_kwargs=dict(
                  mesh_kwargs=mesh_kwargs,
                  initialize_training_loop_kwargs=initialize_training_loop_kwargs,
              ),
          )
          if handler_returns:
            state, step, mesh, jitted_train_step, other_variables = handler_returns
    
      return state
    
    
    def main():
      elastic_manager = manager.Manager(
          devices=jax.devices(),
          snapshot_period=10,
          snapshot_buffer_size=1,
          reshard_check_period=5,
          max_elastic_down_event_count=10,
          max_reshard_retry_count=3,
      )
    
      train_loop(100, elastic_manager, {}, {})
    

    Elastic マネージャーを構成する

    エラスティック マネージャーは、いくつかの異なる方法で構成できます。スナップショットの頻度は、スナップショット期間によって決まります。スナップショット期間は、エラスティック イベントによって失われたステップの平均数に影響します。再シャーディング チェック期間は、トレーニング ループがスライスの可用性をポーリングする頻度を決定します。max_elastic_down_event_count を使用すると、トレーニング ループでサポートするスライス損失によるエラスティック イベントの数を設定できます。max_reshard_retry_count は、エラスティック マネージャーが再シャーディングを再試行する回数を指定します。マネージャーはシングルトン オブジェクトであり、1 回だけ作成する必要があります。

    スナップショット

    エラスティック マネージャーの構成に基づいて、関数はホストメモリにデータをスナップショットします。このデータは、エラスティック イベント中にエラスティック ハンドラで使用できます。

    シャーディングを減らす

    jax.errors.JaxRuntimeError をキャッチすると、Pathways は、エラーがスライスの損失によるエラスティック イベントによるものかどうかを確認します。その場合、成功するか最大再試行回数に達するまで、ループ内でエラスティック ハンドラを呼び出します。エラーがエラスティック イベントによるものでない場合は、エラーが再度発生します。エラスティック ハンドラの戻り値は、呼び出し元に渡されます。

    シャーディングを増やす

    エラスティック マネージャーの構成に基づいて、使用できないスライスがある場合、Pathways は追加のスライスが使用可能になったかどうかを確認します。その場合は、スナップショットを直ちに保存し(現在のステップの既存のスナップショットがまだ取得されていない場合)、成功するか、最大再試行回数に達するまで、ループ内でエラスティック ハンドラを呼び出します。再シャーディングが発生した場合、エラスティック ハンドラの戻り値は呼び出し元に渡されます。それ以外の場合は、None が返されます。

    ホットスワップ

    ホットスワップとは、優先度の高いジョブが優先度の低いジョブからリソースをすばやく引き継ぎ、ダウンタイムを最小限に抑えて迅速な復元を保証する GKE JobSet API の機能です。

    JobSet が作成されると、GKE は JobSet 構成で指定された複数のスライスにワークロードをスケジュールします。1 つ以上のスライスでハードウェア障害が発生すると、影響を受ける Pod は失敗としてマークされます。この Jobset を再スケジュールするときに、優先度の低い Job に使用できる予備のスライスを GKE クラスタに保持することを選択した場合、JobSet システムは、優先度の高い Job の失敗したスライスのワークロードを、同じ GKE クラスタ内の優先度の低い Job で使用されている予備のスライスに再マッピングします。通常、この再マッピングは 1 分以内に完了します。

    JobSet の再起動時に、次の状況でホットスワップが発生する可能性があります。

    1. デフォルト モード: 同じクラスタ内に予備のアイドル状態の TPU スライスがある場合、Kubernetes スケジューラは、失敗したスライスが修復されるのを待つのではなく、再起動されたジョブをこれらのスライスにスケジュールすることを優先します。これにより、復元が高速化されます。
    2. 異種ワークロード: 構成済みの Kubernetes PriorityClass を使用して複数のワークロードを実行しているクラスタでは、再起動された JobSet によってホットスワップがトリガーされる可能性があります。再起動されたジョブのアフィニティが優先度の低いジョブのリソースと一致する場合、Kubernetes は優先度の低いジョブをプリエンプトし、優先度の高いジョブをすぐに開始できるようにします。たとえば、PriorityClass を使用して、異なる優先度で Pathways ワーカー Pod を構成できます。

    クラスタで優先度を使用するには、次のように優先度クラスを定義します。

    kind: PriorityClass
    metadata:
      name: high-prior-job
    value: 2000
    globalDefault: false
    description: "This priority class should be used for high priority job."
    

    この YAML を GKE クラスタに適用します。

    kubectl apply -f high-prior-job.yaml
    

    次に、pathways-worker Pod の podspec に次のテキストを追加して、新しい優先度クラスを Pathways ワーカー ジョブに適用します。

    priorityClassName: high-prior-job
    

    次のステップ