Pathways を使用した復元力のあるトレーニング

Pathways には、次のような復元性のメリットがあります。

  • 停止と再開: ユーザーがカスタムのプリエンプション 処理コードを記述しなくても、 プリエンプション通知などの計画的な中断に対応できます。
  • Elastic Training: 計画外のハードウェア障害に対応できます クライアントがクラッシュすることなく、ただしユーザーはモデル 固有の復旧コードを記述する必要があります。

始める前に

インストールに必要なもの:

停止と再開

通常、GKE は Pod がプリエンプトされる前に、プリエンプション通知をアクセラレータ Pod に送信します。Pathways のプリエンプション許容度は、すべてのクラウド デプロイでデフォルトで有効になっており、Pathways アクセラレータ ジョブはこれらの通知をリッスンします。

プリエンプション通知が届くと、Pathways はまず現在のワークロードを復元できるかどうか、つまり Pathways がワークロードを透過的に保存して復元できるかどうかを判断します。復元できる場合は、GKE がアクセラレータ ジョブをエビクションする前に、現在の状態を Cloud Storage などの永続ストレージに書き込むことで、ML ワークロードを透過的に一時停止しようとします。GKE が後でジョブを再スケジュールすると、Pathways は永続化された状態を読み取って ML ワークロードを再開します。

ワークロードを復元できない場合、Pathways はアクセラレータ ジョブをシャットダウンし、弾力性のあるトレーニング が構成されている場合は、障害をジョブに転送します。弾力性のあるトレーニングが構成されていない場合、GKE は JobSet の再起動ポリシーに基づいてワークロード全体を再起動します。

JAX を使用して定義された一般的な ML ワークロードは、ステートレスの Pathways XLA コンポーネントに依存しています。これらのコンポーネントは、高帯域幅メモリ(HBM)スナップショットを使用して復元できます。JAX のコロケーションされた Python API を使用して定義された ML ワークロードなど、一部の ML ワークロードはステートフル Pathways コンポーネントに依存しています。これらは復元できません。

弾力性のあるトレーニング

弾力性のあるトレーニングを使用すると、ハードウェア障害が発生した場合でもトレーニング ジョブを続行できます。これは、Pathways システムの機能とユーザー定義のモデル復旧ロジックを組み合わせることで実現されます。

  • 障害の検出: ハードウェア障害(TPU ワーカーのクラッシュなど)が発生すると、Pathways システムがこれを検出し、そのハードウェアに配置されていたデータに次回アクセスしたときに例外を介してユーザーの トレーニング ジョブに通知します。この通知によってワークロードがクラッシュすることはありません。コードで通知を処理し、リソースを再構成して処理を続行するか、正常に終了できます。
  • ユーザー定義の弾力性ハンドラ: ユーザーのモデルコードで この例外を処理できる必要があります。これが「モデル固有の復旧」です。
    • スナップショット: 最も一般的な方法は、モデルの状態のスナップショットを定期的に保存することです。障害が発生した場合は、最新のスナップショットから読み込んでトレーニングを再開できます。
    • 再構成: 使用可能なスライスの数に合わせてトレーニング ジョブを 再構成する必要がある場合があります。たとえば、1 つのスライスが動作を停止した場合は、代替スライスが使用可能になるまで、アクティブなスライスの数を 1 つ減らすことができます。詳細については、弾力性ハンドラをご覧ください。
    • データグラフ/計算グラフの更新: 必要に応じて計算グラフを再作成することで、計算に使用できるデバイス数の変更をコードで処理する必要があります。これには、データの再パーティショニングやモデルの再コンパイルが必要になる場合があります。
  • 復旧における Pathways の役割 : Pathways は、 ユーザー定義の再構成 をサポートするプリミティブを提供します。
    • スライスの置換: 障害が発生したスライスが置換された場合、クライアントは 新しいスライスが使用可能になったときに通知できます。その後、コードを再構成してこの新しいスライスを使用できます。
    • 透過的な復旧: Pathways は、クラスタの正常な部分への接続の再確立など、復旧の下位レベルの詳細を処理します。
  • pathwaysutils のユーティリティ: pathways-utils で定義された Pathways ユーティリティのセット。

弾力性ハンドラを実装する

記述する必要があるコードのほとんどは、ユーザー定義の弾力性ハンドラに記述します。このハンドラは、TPU スライスが使用できなくなるなどの弾力性イベントに応答して、メッシュを再作成し、トレーニング ループを再初期化します。

ワークロードはそれぞれ異なります。弾力性ハンドラの複雑さは、ワークロードの複雑さに応じてスケーリングできます。ハンドラの入力と出力は、トレーニング ループを再初期化するために必要な最小限の引数と戻り値にする必要があります。

def elastic_handler(elastic_utils, *args, **kwargs):
  mesh = initialize_mesh(**kwargs["mesh_kwargs"])
  initial_state, initial_step, jitted_train_step, other_variables =
      initialize_training_loop(mesh, **kwargs["initialize_training_loop_kwargs"])

  step, snapshot = elastic_utils.get_next_snapshot()
  state = initial_state.replace(**snapshot)

  return state, step, mesh, jitted_train_step, other_variables

トレーニング ループを更新する

トレーニング ループに次の変更を加える必要があります。

  1. 弾力性マネージャーを作成する
  2. jax.errors.JaxRuntimeError を処理する try-except ブロックでトレーニング ループをラップする
  3. jax.errors.JaxRuntimeError ハンドラ内で maybe_reshard_down を呼び出す。エラーが弾力性イベントに関連している場合、弾力性マネージャーは再シャーディングを行います。それ以外の場合は、エラーを再発生させます。
  4. トレーニング ループの最後に maybe_snapshotmaybe_reshard_up を呼び出す
import pathwaysutils
from pathwaysutils.elastic import manager

pathwaysutils.initialize()

def initialize_mesh(**kwargs):
  ...


def initialize_training_loop(**kwargs):
  ...


def train_loop(
    final_step,
    elastic_manager,
    mesh_kwargs,
    initialize_training_loop_kwargs,
):
  mesh = initialize_mesh(**mesh_kwargs)
  initial_state, initial_step, jitted_train_step, other_variables =
      initialize_training_loop(mesh, **initialize_training_loop_kwargs)

  step = initial_step
  while step < final_step:
    try:
      state = jitted_train_step(state)

      elastic_manager.maybe_snapshot(step=step, snapshot=state)
      handler_returns = elastic_manager.maybe_reshard_up(
          step=step,
          snapshot=state,
          elastic_handler=elastic_handler,
          handler_args=(),
          handler_kwargs=dict(
              mesh_kwargs=mesh_kwargs,
              initialize_training_loop_kwargs=initialize_training_loop_kwargs,
          ),
      )
      if handler_returns:
        state, step, mesh, jitted_train_step, other_variables = handler_returns
      step += 1
    except jax.errors.JaxRuntimeError as error:
      handler_returns = elastic_manager.maybe_reshard_down(
          error=error,
          elastic_handler=elastic_handler,
          handler_args=(),
          handler_kwargs=dict(
              mesh_kwargs=mesh_kwargs,
              initialize_training_loop_kwargs=initialize_training_loop_kwargs,
          ),
      )
      if handler_returns:
        state, step, mesh, jitted_train_step, other_variables = handler_returns

  return state


def main():
  elastic_manager = manager.Manager(
      devices=jax.devices(),
      snapshot_period=10,
      snapshot_buffer_size=1,
      reshard_check_period=5,
      max_elastic_down_event_count=10,
      max_reshard_retry_count=3,
  )

  train_loop(100, elastic_manager, {}, {})

弾力性マネージャーを構成する

弾力性マネージャーは、いくつかの方法で構成できます。スナップショットの頻度は、スナップショット期間によって決まります。スナップショット期間は、弾力性イベントによって失われるステップの平均数に影響します。再シャーディング チェック期間は、トレーニング ループがスライスの可用性をポーリングする頻度を決定します。 max_elastic_down_event_count を使用すると、スライスの損失による弾力性イベントの数を設定できます。トレーニング ループでサポートされます。max_reshard_retry_count は、弾力性マネージャーが再シャーディングを再試行する回数を指定します。マネージャーはシングルトン オブジェクトであり、1 回だけ作成する必要があります。

スナップショット

弾力性マネージャーの構成に基づいて、関数はデータをホストメモリにスナップショットできます。このデータは、弾力性イベント中に弾力性ハンドラで使用できます。

シャーディングを減らす

jax.errors.JaxRuntimeError をキャッチすると、Pathways は、エラーがスライスの損失による弾力性イベントによるものかどうかを確認します。その場合は、成功するか、最大再試行回数に達するまで、弾力性ハンドラをループで呼び出します。エラーが弾力性イベントによるものでない場合は、エラーが再度発生します。弾力性ハンドラの戻り値は、呼び出し元に渡されます。

シャーディングを増やす

弾力性マネージャーの構成に基づいて、使用できないスライスがある場合、Pathways は追加のスライスが使用可能になったかどうかを確認します。使用可能になった場合は、すぐにスナップショットを保存し(現在のステップの既存のスナップショットがまだ作成されていない場合)、成功するか、最大再試行回数に達するまで、弾力性ハンドラをループで呼び出します。再シャーディングが発生した場合、弾力性ハンドラの戻り値は呼び出し元に渡されます。それ以外の場合は、None が返されます。

ホットスワップ

ホットスワップとは、GKE JobSet API の機能で、優先度の高いジョブが優先度の低いジョブからリソースを迅速に引き継ぎ、ダウンタイムを最小限に抑えて迅速な復旧を実現します。

JobSet が作成されると、GKE は JobSet 構成で指定されたとおりに、複数のスライスにワークロードをスケジュールします。1 つ以上のスライスでハードウェア障害が発生すると、影響を受ける Pod は失敗としてマークされます。この Jobset を再スケジュールするときに、優先度の低いジョブに使用できる予備のスライスを GKE クラスタに保持することを選択した場合、JobSet システムは、優先度の高いジョブの障害が発生したスライスのワークロードを、同じ GKE クラスタ内の優先度の低いジョブで使用されている予備のスライスに再マッピングします。通常、この再マッピングは 1 分以内に完了します。

JobSet の再起動時に、次のような状況でホットスワップが発生する可能性があります。

  1. デフォルト モード: 同じ クラスタ内に予備のアイドル TPU スライスがある場合、Kubernetes スケジューラは、障害が発生したスライスが修復されるのを待つのではなく、再起動された ジョブをこれらのスライスに優先的にスケジュールします。これにより、復旧が迅速になります。
  2. 異種ワークロード: 構成済みの Kubernetes PriorityClass を使用して複数のワークロードを実行しているクラスタでは、再起動された JobSet によってホットスワップがトリガーされる可能性があります。再起動されたジョブのアフィニティが優先度の低いジョブのリソースと一致する場合、Kubernetes は優先度の低いジョブをプリエンプトし、優先度の高いジョブをすぐに開始できます。たとえば、PriorityClass を使用して、Pathways ワーカー Pod に異なる優先度を構成できます。

クラスタで優先度を使用するには、優先度クラスを定義します。例:

kind: PriorityClass
metadata:
  name: high-prior-job
value: 2000
globalDefault: false
description: "This priority class should be used for high priority job."

この YAML を GKE クラスタに適用します。

kubectl apply -f high-prior-job.yaml

次に、pathways-worker Pod の podspec に次のテキストを追加して、新しい優先度クラスを Pathways ワーカー ジョブに接続します。

priorityClassName: high-prior-job

次のステップ