Pathways を使用した復元力のあるトレーニング

Pathways は、次のような方法で復元力のメリットを提供します。

  • 一時停止と再開: ユーザーがカスタムのプリエンプション処理コードを記述しなくても、 プリエンプション通知などの計画的な中断に対応できます。
  • Elastic Training: 計画外のハードウェア障害が発生しても、クライアントがクラッシュすることなく対応できますが、ユーザーはモデル固有の復旧コードを記述する必要があります。

始める前に

インストールに必要なもの:

一時停止と再開

通常、GKE は Pod がプリエンプトされる前に、アクセラレータ Pod にプリエンプション通知を送信します。Pathways のプリエンプション許容度は、すべてのクラウド デプロイでデフォルトで有効になっており、Pathways アクセラレータ ジョブはこれらの通知をリッスンします。

プリエンプション通知が届くと、Pathways はまず、現在のワークロードが復元可能かどうか、つまり Pathways がワークロードを透過的に保存して復元できるかどうかを判断します。復元可能な場合は、GKE がアクセラレータ ジョブをエビクションする前に、現在の状態を Cloud Storage などの永続ストレージに書き込むことで、ML ワークロードを透過的に一時停止しようとします。GKE が後でジョブを再スケジュールすると、Pathways は永続化された状態を読み取って ML ワークロードを再開します。

ワークロードが復元できない場合、Pathways はアクセラレータ ジョブをシャットダウンし、Elastic トレーニング が構成されている場合は、ジョブに障害を転送します。Elastic トレーニングが構成されていない場合、GKE は JobSet の再起動ポリシーに基づいてワークロード全体を再起動します。

JAX を使用して定義された一般的な ML ワークロードは、ステートレス Pathways XLA コンポーネントに依存しています。これらのコンポーネントは、高帯域幅メモリ(HBM)スナップショットを使用して復元できます。JAX のコロケーションされた Python API を使用して定義された ML ワークロードなど、一部の ML ワークロードはステートフル Pathways コンポーネントに依存しています。これらは復元できません。

弾力性のあるトレーニング

弾力性のあるトレーニングを使用すると、ハードウェア障害が発生した場合でもトレーニング ジョブを続行できます。これは、Pathways システムの機能とユーザー定義のモデル復旧ロジックの組み合わせによって実現されます。

  • 障害の検出: ハードウェア障害が発生した場合(TPU ワーカーがクラッシュした場合など)、Pathways システムはこれを検出し、そのハードウェアに配置されていたデータに次回アクセスしたときに例外を介してユーザーの トレーニング ジョブに通知します。この通知によってワークロードがクラッシュすることはありません。コードで通知を処理し、リソースを再構成して処理を続行するか、正常に終了できます。
  • ユーザー定義の伸縮ハンドラ: ユーザーのモデルコードで この例外を処理できる必要があります。これが「モデル固有の復旧」です。
    • スナップショット: 最も一般的な方法は、モデルの状態のスナップショットを定期的に保存することです。障害が発生した場合は、最新のスナップショットから読み込んでトレーニングを再開できます。
    • 再構成: 使用可能なスライスの数に合わせてトレーニング ジョブを 再構成する必要がある場合があります。たとえば、1 つのスライスが動作を停止した場合、代替が利用可能になるまでアクティブなスライスの数を 1 つ減らすことができます。詳細については、伸縮ハンドラをご覧ください。
    • データグラフ/計算グラフの更新: 必要に応じて計算グラフを再作成することで、計算に使用できるデバイスの数の変更をコードで処理する必要があります。これには、データの再パーティショニングやモデルの再コンパイルが必要になる場合があります。
  • 復旧における Pathways の役割 : Pathways は、 ユーザー定義の再構成 をサポートするプリミティブを提供します。
    • スライスの置換: 失敗したスライスが置き換えられた場合、クライアントは新しいスライスが利用可能になったときに 通知できます。その後、コードを再構成してこの新しいスライスを使用できます。
    • 透過的な復旧: Pathways は、クラスタの正常な部分への接続の再確立など、復旧の下位レベルの詳細を処理します。
  • pathwaysutils のユーティリティ: pathways-utils で定義された Pathways ユーティリティのセット。

伸縮ハンドラを実装する

記述する必要があるコードのほとんどは、ユーザー定義の伸縮ハンドラに記述します。このハンドラは、TPU スライスが使用できなくなるなどの伸縮イベントに応答して、メッシュを再作成し、トレーニング ループを再初期化します。

ワークロードはそれぞれ異なります。伸縮ハンドラの複雑さは、ワークロードの複雑さに応じてスケーリングされる可能性があります。ハンドラの入力と出力は、トレーニング ループを再初期化するために必要な最小限の引数と戻り値にする必要があります。

def elastic_handler(elastic_utils, *args, **kwargs):
  mesh = initialize_mesh(**kwargs["mesh_kwargs"])
  initial_state, initial_step, jitted_train_step, other_variables =
      initialize_training_loop(mesh, **kwargs["initialize_training_loop_kwargs"])

  step, snapshot = elastic_utils.get_next_snapshot()
  state = initial_state.replace(**snapshot)

  return state, step, mesh, jitted_train_step, other_variables

トレーニング ループを更新する

トレーニング ループに次の変更を加える必要があります。

  1. 伸縮マネージャーを作成する
  2. jax.errors.JaxRuntimeError を処理する try-except ブロックでトレーニング ループをラップする
  3. jax.errors.JaxRuntimeError ハンドラ内で maybe_reshard_down を呼び出す。エラーが伸縮イベントに関連している場合、伸縮マネージャーは再シャーディングを行います。それ以外の場合は、エラーを再発生させます。
  4. トレーニング ループの最後に maybe_snapshotmaybe_reshard_up を呼び出す
import pathwaysutils
from pathwaysutils.elastic import manager

pathwaysutils.initialize()

def initialize_mesh(**kwargs):
  ...


def initialize_training_loop(**kwargs):
  ...


def train_loop(
    final_step,
    elastic_manager,
    mesh_kwargs,
    initialize_training_loop_kwargs,
):
  mesh = initialize_mesh(**mesh_kwargs)
  initial_state, initial_step, jitted_train_step, other_variables =
      initialize_training_loop(mesh, **initialize_training_loop_kwargs)

  step = initial_step
  while step < final_step:
    try:
      state = jitted_train_step(state)

      elastic_manager.maybe_snapshot(step=step, snapshot=state)
      handler_returns = elastic_manager.maybe_reshard_up(
          step=step,
          snapshot=state,
          elastic_handler=elastic_handler,
          handler_args=(),
          handler_kwargs=dict(
              mesh_kwargs=mesh_kwargs,
              initialize_training_loop_kwargs=initialize_training_loop_kwargs,
          ),
      )
      if handler_returns:
        state, step, mesh, jitted_train_step, other_variables = handler_returns
      step += 1
    except jax.errors.JaxRuntimeError as error:
      handler_returns = elastic_manager.maybe_reshard_down(
          error=error,
          elastic_handler=elastic_handler,
          handler_args=(),
          handler_kwargs=dict(
              mesh_kwargs=mesh_kwargs,
              initialize_training_loop_kwargs=initialize_training_loop_kwargs,
          ),
      )
      if handler_returns:
        state, step, mesh, jitted_train_step, other_variables = handler_returns

  return state


def main():
  elastic_manager = manager.Manager(
      devices=jax.devices(),
      snapshot_period=10,
      snapshot_buffer_size=1,
      reshard_check_period=5,
      max_elastic_down_event_count=10,
      max_reshard_retry_count=3,
  )

  train_loop(100, elastic_manager, {}, {})

伸縮マネージャーを構成する

伸縮マネージャーは、いくつかの方法で構成できます。スナップショットの頻度は、スナップショット期間によって決まります。スナップショット期間は、伸縮イベントによって失われるステップの平均数に影響します。再シャーディング チェック期間は、トレーニング ループがスライスの可用性をポーリングする頻度を決定します。 max_elastic_down_event_count を使用すると、スライスの損失による伸縮イベントの数を設定できます。トレーニング ループでサポートされます。max_reshard_retry_count は、伸縮マネージャーが再シャーディングを再試行する回数を指定します。マネージャーはシングルトン オブジェクトであり、1 回だけ作成する必要があります。

スナップショット

伸縮マネージャーの構成に基づいて、伸縮イベント中に伸縮ハンドラで使用できるホストメモリにデータをスナップショットできます。

シャーディングを減らす

jax.errors.JaxRuntimeError をキャッチすると、Pathways は、エラーがスライスの損失による伸縮イベントによるものかどうかを確認します。その場合は、成功するか最大再試行回数に達するまで、伸縮ハンドラをループで呼び出します。エラーが伸縮イベントによるものでない場合は、エラーが再度発生します。伸縮ハンドラの戻り値は、呼び出し元に渡されます。

シャーディングを増やす

伸縮マネージャーの構成に基づいて、使用できないスライスがある場合、Pathways は追加のスライスが使用可能になったかどうかを確認します。使用可能になった場合は、すぐにスナップショットを保存し(現在のステップの既存のスナップショットがまだ作成されていない場合)、成功するか最大再試行回数に達するまで、伸縮ハンドラをループで呼び出します。再シャーディングが発生した場合、伸縮ハンドラの戻り値は呼び出し元に渡されます。それ以外の場合は、None が返されます。

ホットスワップ

ホットスワップとは、GKE JobSet API の機能で、優先度の高いジョブが優先度の低いジョブからリソースを迅速に引き継ぎ、ダウンタイムを最小限に抑えて迅速な復旧を実現します。

JobSet が作成されると、GKE は JobSet 構成で指定されたとおりに、複数のスライスにワークロードをスケジュールします。1 つ以上のスライスでハードウェア障害が発生すると、影響を受ける Pod は失敗としてマークされます。この Jobset を再スケジュールするときに、優先度の低いジョブに使用できる予備のスライスを GKE クラスタに保持することを選択した場合、JobSet システムは、優先度の高いジョブの失敗したスライスのワークロードを、同じ GKE クラスタ内の優先度の低いジョブで使用されている予備のスライスに再マッピングします。通常、この再マッピングには 1 分もかかりません。

JobSet の再起動時に、次のような状況でホットスワップが発生する可能性があります。

  1. デフォルト モード: 同じ クラスタ内に予備のアイドル TPU スライスがある場合、Kubernetes スケジューラは、失敗したスライスが修復されるのを待つのではなく、再起動された ジョブをこれらのスライスに優先的にスケジュールします。これにより、復旧が迅速になります。
  2. 異種ワークロード: Kubernetes PriorityClass が構成された複数のワークロードを実行しているクラスタでは、再起動された JobSet によってホットスワップがトリガーされる可能性があります。再起動されたジョブのアフィニティが優先度の低いジョブのリソースと一致する場合、Kubernetes は優先度の低いジョブをプリエンプトし、優先度の高いジョブをすぐに開始できます。たとえば、PriorityClass を使用して、Pathways ワーカー Pod に異なる優先度を構成できます。

クラスタで優先度を使用するには、次のように優先度クラスを定義します。

kind: PriorityClass
metadata:
  name: high-prior-job
value: 2000
globalDefault: false
description: "This priority class should be used for high priority job."

この YAML を GKE クラスタに適用します。

kubectl apply -f high-prior-job.yaml

次に、pathways-worker Pod の podspec に次のテキストを追加して、新しい優先度クラスを Pathways ワーカー ジョブにアタッチします。

priorityClassName: high-prior-job

次のステップ