通过 Pathways 进行弹性训练

Pathways 通过以下方式提供弹性优势:

  • 暂停-恢复:在面对 抢占通知等计划内中断时,无需用户编写任何自定义抢占 处理代码即可容忍中断。
  • 弹性训练:在面对计划外硬件故障时 无需导致客户端崩溃,但需要用户编写特定于模型 的恢复代码。

准备工作

请确保您已备妥:

暂停-恢复

通常,GKE 会在抢占加速器 Pod 之前向其发送抢占通知。默认情况下,所有云部署都启用了 Pathways 抢占容忍功能,Pathways 加速器作业会监听这些通知。

收到抢占通知后,Pathways 首先会确定当前工作负载是否可恢复,即 Pathways 是否可以透明地保存和恢复工作负载。如果可以,则它会尝试通过将机器学习工作负载的当前状态写入永久性存储(例如 Cloud Storage),在 GKE 逐出加速器作业之前透明地暂停机器学习工作负载。当 GKE 稍后重新调度作业时,Pathways 会通过读取 其持久化状态来恢复机器学习工作负载。

如果工作负载不可恢复,Pathways 会关闭加速器 作业,并在配置了弹性训练 的情况下将故障转发给作业。如果未配置弹性训练,GKE 会根据 JobSet 重启政策重启整个工作负载。

使用 JAX 定义的典型机器学习工作负载依赖于无状态 Pathways XLA 组件,这些组件可以使用高带宽内存 (HBM) 快照进行恢复。某些机器学习 工作负载(例如使用JAX 并置 Python API 定义的工作负载)依赖于有状态 Pathways 组件;这些组件不可恢复。

弹性训练

借助弹性训练,即使发生硬件故障,训练作业也能继续进行。这是通过 Pathways 系统功能和用户定义的模型恢复逻辑相结合来实现的:

  • 检测故障:当发生硬件故障(例如 TPU 工作器崩溃)时,Pathways 系统会检测到此故障,并在下次访问该 硬件上的数据时通过异常通知用户的 训练作业。此通知不会导致工作负载崩溃;它允许您的代码处理通知并重新配置资源,以继续处理或正常退出。
  • 用户定义的弹性处理程序:用户的模型代码需要能够处理此异常。这使其成为“特定于模型的恢复”。
    • 快照:最常见的方法是定期保存模型状态的快照 。发生故障时,您可以从最新的快照加载以恢复训练。
    • 重新配置:您可能需要重新配置训练作业,以 适应可用切片的数量。例如,如果一个切片停止工作,您可能会将活跃切片的数量减少一个,直到有替代切片可用为止。如需了解详情,请参阅弹性处理程序
    • 数据/计算图更新:您的代码需要通过根据需要重新创建计算图来处理可用于计算的设备数量的任何变化。这可能涉及重新分区数据或重新编译模型。
  • Pathways 在恢复中的作用:Pathways 提供用于支持 用户定义重新配置的原语:
    • 切片替换:如果替换了失败的切片,则可以在新切片可用后 通知客户端。然后,您的代码可以重新配置以使用此新切片。
    • 透明恢复:Pathways 处理 恢复的较低级别详细信息,例如重新建立与 集群健康部分的连接。
  • pathwaysutils 中的实用程序:在 pathways-utils中定义的一组 Pathways 实用程序。

实现弹性处理程序

您必须编写的大部分代码都位于用户定义的弹性处理程序中。此处理程序通过重新创建网状网并重新初始化训练循环来响应弹性事件(例如 TPU 切片变得不可用)。

每个工作负载都是独一无二的。弹性处理程序的复杂性可能会随着工作负载的复杂性而变化。处理程序的输入和输出应该是重新初始化训练循环所需的最少参数和返回值。

def elastic_handler(elastic_utils, *args, **kwargs):
  mesh = initialize_mesh(**kwargs["mesh_kwargs"])
  initial_state, initial_step, jitted_train_step, other_variables =
      initialize_training_loop(mesh, **kwargs["initialize_training_loop_kwargs"])

  step, snapshot = elastic_utils.get_next_snapshot()
  state = initial_state.replace(**snapshot)

  return state, step, mesh, jitted_train_step, other_variables

更新训练循环

您需要对训练循环进行以下更改:

  1. 创建弹性管理器
  2. 将训练循环封装在处理 jax.errors.JaxRuntimeError 的 try-except 块中
  3. jax.errors.JaxRuntimeError 处理程序中,调用 maybe_reshard_down。如果错误与弹性事件相关,弹性管理器将重新分片,否则将重新引发该错误。
  4. 在训练循环结束时调用 maybe_snapshotmaybe_reshard_up
import pathwaysutils
from pathwaysutils.elastic import manager

pathwaysutils.initialize()

def initialize_mesh(**kwargs):
  ...


def initialize_training_loop(**kwargs):
  ...


def train_loop(
    final_step,
    elastic_manager,
    mesh_kwargs,
    initialize_training_loop_kwargs,
):
  mesh = initialize_mesh(**mesh_kwargs)
  initial_state, initial_step, jitted_train_step, other_variables =
      initialize_training_loop(mesh, **initialize_training_loop_kwargs)

  step = initial_step
  while step < final_step:
    try:
      state = jitted_train_step(state)

      elastic_manager.maybe_snapshot(step=step, snapshot=state)
      handler_returns = elastic_manager.maybe_reshard_up(
          step=step,
          snapshot=state,
          elastic_handler=elastic_handler,
          handler_args=(),
          handler_kwargs=dict(
              mesh_kwargs=mesh_kwargs,
              initialize_training_loop_kwargs=initialize_training_loop_kwargs,
          ),
      )
      if handler_returns:
        state, step, mesh, jitted_train_step, other_variables = handler_returns
      step += 1
    except jax.errors.JaxRuntimeError as error:
      handler_returns = elastic_manager.maybe_reshard_down(
          error=error,
          elastic_handler=elastic_handler,
          handler_args=(),
          handler_kwargs=dict(
              mesh_kwargs=mesh_kwargs,
              initialize_training_loop_kwargs=initialize_training_loop_kwargs,
          ),
      )
      if handler_returns:
        state, step, mesh, jitted_train_step, other_variables = handler_returns

  return state


def main():
  elastic_manager = manager.Manager(
      devices=jax.devices(),
      snapshot_period=10,
      snapshot_buffer_size=1,
      reshard_check_period=5,
      max_elastic_down_event_count=10,
      max_reshard_retry_count=3,
  )

  train_loop(100, elastic_manager, {}, {})

配置弹性管理器

弹性管理器可以通过几种不同的方式进行配置。快照的频率由快照周期决定。快照周期会影响因弹性事件而丢失的平均步数。重新分片检查周期决定了训练循环轮询切片可用性的频率。 借助 max_elastic_down_event_count,您可以设置训练循环将支持的因切片丢失而发生的弹性事件的数量。max_reshard_retry_count 指定了弹性管理器应重试重新分片的次数。管理器是单例对象,应仅创建一次。

快照

根据弹性管理器配置,该函数可能会将数据快照到主机内存中,弹性处理程序在弹性事件期间可以使用这些数据。

减少分片

捕获 jax.errors.JaxRuntimeError 后,Pathways 将检查错误是否是因切片丢失而发生的弹性事件所致。如果是,它将循环调用弹性处理脚本,直到成功或达到最大重试次数。如果错误不是因弹性事件所致,则会再次引发该错误。弹性处理脚本的返回值将传递给调用方。

增加分片

根据弹性管理器配置以及是否存在不可用的切片,Pathways 将检查是否有其他切片可用。如果是,它将立即保存快照(如果尚未为当前步拍摄现有快照),并循环调用弹性处理程序,直到成功或达到最大重试次数。如果发生重新分片,弹性处理程序的返回值将传递给调用方。否则,将返回 None

热插拔

热插拔是指 GKE JobSet API 的一项功能,借助该功能,优先级较高的作业可以快速接管优先级较低的作业的资源,从而最大限度地减少停机时间并确保更快的恢复速度。

创建 JobSet 后,GKE 会根据 JobSet 配置中的规定,在多个切片上调度工作负载。如果一个或多个切片发生硬件故障,受影响的 Pod 将被标记为失败。在重新调度此 JobSet 时,如果您选择在 GKE 集群中保留一个备用切片,该切片可用于优先级较低的作业,则 JobSet 系统会将优先级较高的作业的失败切片的工作负载重新映射到同一 GKE 集群中优先级较低的作业正在使用的备用切片上。此重新映射通常用时不到 1 分钟。

JobSet 重启后,可能会在以下情况下发生热插拔:

  1. 默认模式:如果同一 集群中有备用空闲 TPU 切片可用,Kubernetes 调度器将优先将重启的 作业调度到这些切片上,而不是等待修复失败的切片。这样可以加快恢复速度。
  2. 异构工作负载:在运行多个工作负载且配置了 Kubernetes PriorityClass 的集群中,重启的 JobSet 可以触发热插拔。如果重启的作业的亲和性与优先级较低的作业的资源匹配,Kubernetes 会抢占优先级较低的作业,从而允许优先级较高的作业立即启动。例如,您可以使用 PriorityClass 为 Pathways 工作器 Pod 配置不同的优先级。

如需在集群中使用优先级,请定义优先级类,例如:

kind: PriorityClass
metadata:
  name: high-prior-job
value: 2000
globalDefault: false
description: "This priority class should be used for high priority job."

将此 YAML 应用到您的 GKE 集群:

kubectl apply -f high-prior-job.yaml

接下来,通过将以下文本添加到 pathways-worker Pod 的 podspec 中,将新的优先级类附加到 Pathways 工作器作业。

priorityClassName: high-prior-job

后续步骤