途径可通过以下方式提供恢复能力优势:
- 暂停-恢复:在面对计划内的中断(例如抢占通知)时,无需用户编写任何自定义抢占处理代码即可实现容错。
- 弹性训练:在面对意外的硬件故障时具有容错能力,不会导致客户端崩溃,但需要用户编写特定于模型的恢复代码。
准备工作
请确保您已备妥:
暂停-恢复
通常,GKE 会在抢占加速器 Pod 之前向该 Pod 发送抢占通知。默认情况下,所有云部署都启用了 Pathways 抢占容忍度,并且 Pathways 加速器作业会监听这些通知。
当抢占通知到达时,Pathways 首先会确定当前工作负载是否可恢复,即 Pathways 是否可以透明地保存和恢复工作负载。如果您的 ML 工作负载正在使用加速器,那么在 GKE 逐出加速器作业之前,它会尝试通过将当前状态写入 Cloud Storage 等永久性存储空间来透明地暂停您的 ML 工作负载。当 GKE 稍后重新调度作业时,Pathways 会通过回读其持久化状态来恢复机器学习工作负载。
如果工作负载无法恢复,Pathways 会关闭加速器作业,并在配置了弹性训练的情况下将故障转发给您的作业。如果未配置弹性训练,GKE 会根据 JobSet 重启政策重启整个工作负载。
使用 JAX 定义的典型机器学习工作负载依赖于无状态 Pathways XLA 组件,这些组件可以使用高带宽内存 (HBM) 快照进行恢复。某些机器学习工作负载(例如使用 JAX 并置 Python API 定义的工作负载)依赖于有状态的 Pathways 组件;这些组件无法恢复。
弹性训练
弹性训练功能可让您的训练作业在发生硬件故障时继续运行。这是通过结合使用 Pathways 系统功能和用户定义的模型恢复逻辑来实现的:
- 检测故障:当发生硬件故障(例如,TPU 工作器崩溃)时,Pathways 系统会检测到此故障,并在下次访问位于该硬件上的数据时通过异常通知用户的训练作业。此通知不会导致工作负载崩溃;它允许您的代码处理通知并重新配置资源,以继续处理或正常退出。
- 用户定义的弹性处理程序:用户的模型代码需要能够处理此异常。这使得它成为“特定于型号的恢复”。
- 拍摄快照:最常见的方法是定期保存模型状态的快照。发生故障时,您可以从最近的快照加载以恢复训练。
- 重新配置:您可能需要重新配置训练作业,以根据可用的切片数量进行调整。例如,如果某个切片停止工作,您可能会将有效切片的数量减少一个,直到有替代切片可用为止。如需了解详情,请参阅 Elastic 处理程序。
- 数据/计算图更新:您的代码需要通过根据需要重新创建计算图来处理可用于计算的设备数量的任何变化。这可能涉及重新划分数据或重新编译模型。
- Pathways 在恢复中的作用:Pathways 提供支持用户定义重新配置的原语:
- 切片替换:如果替换了失败的切片,则可以在新切片可用时通知客户端。然后,您的代码可以重新配置为使用此新切片。
- 透明恢复:Pathways 会处理恢复的较低级别详细信息,例如重新建立与集群健康部分的连接。
- pathwaysutils 中的实用程序:一组在 pathways-utils 中定义的 Pathways 实用程序。
实现弹性处理程序
您需要编写的大部分代码都将位于用户定义的弹性处理程序中。此处理程序通过重新创建网格并重新初始化训练循环来响应弹性事件(例如 TPU 切片变得不可用)。
每项工作负载都是独一无二的。弹性处理程序的复杂性可能会随工作负载的复杂性而变化。处理程序的输入和输出应为重新初始化训练循环所需的最少实参和返回值。
def elastic_handler(elastic_utils, *args, **kwargs): mesh = initialize_mesh(**kwargs["mesh_kwargs"]) initial_state, initial_step, jitted_train_step, other_variables = initialize_training_loop(mesh, **kwargs["initialize_training_loop_kwargs"]) step, snapshot = elastic_utils.get_next_snapshot() state = initial_state.replace(**snapshot) return state, step, mesh, jitted_train_step, other_variables更新训练循环
您需要对训练循环进行以下更改:
- 创建弹性管理器
- 将训练循环封装在处理
jax.errors.JaxRuntimeError的 try-except 块中 - 在
jax.errors.JaxRuntimeError处理程序中,调用maybe_reshard_down。如果错误与弹性事件相关,弹性管理器将重新分片;否则,将重新引发该错误。 - 在训练循环结束时调用
maybe_snapshot和maybe_reshard_up
import pathwaysutils from pathwaysutils.elastic import manager pathwaysutils.initialize() def initialize_mesh(**kwargs): ... def initialize_training_loop(**kwargs): ... def train_loop( final_step, elastic_manager, mesh_kwargs, initialize_training_loop_kwargs, ): mesh = initialize_mesh(**mesh_kwargs) initial_state, initial_step, jitted_train_step, other_variables = initialize_training_loop(mesh, **initialize_training_loop_kwargs) step = initial_step while step < final_step: try: state = jitted_train_step(state) elastic_manager.maybe_snapshot(step=step, snapshot=state) handler_returns = elastic_manager.maybe_reshard_up( step=step, snapshot=state, elastic_handler=elastic_handler, handler_args=(), handler_kwargs=dict( mesh_kwargs=mesh_kwargs, initialize_training_loop_kwargs=initialize_training_loop_kwargs, ), ) if handler_returns: state, step, mesh, jitted_train_step, other_variables = handler_returns step += 1 except jax.errors.JaxRuntimeError as error: handler_returns = elastic_manager.maybe_reshard_down( error=error, elastic_handler=elastic_handler, handler_args=(), handler_kwargs=dict( mesh_kwargs=mesh_kwargs, initialize_training_loop_kwargs=initialize_training_loop_kwargs, ), ) if handler_returns: state, step, mesh, jitted_train_step, other_variables = handler_returns return state def main(): elastic_manager = manager.Manager( devices=jax.devices(), snapshot_period=10, snapshot_buffer_size=1, reshard_check_period=5, max_elastic_down_event_count=10, max_reshard_retry_count=3, ) train_loop(100, elastic_manager, {}, {})配置弹性管理器
您可以通过几种不同的方式配置弹性管理器。拍摄快照的频率由快照周期决定。快照周期会影响因弹性事件而丢失的平均步数。重新分片检查周期决定了训练循环轮询切片可用性的频率。通过
max_elastic_down_event_count,您可以设置训练循环将支持的因切片丢失而导致的弹性事件数量。max_reshard_retry_count用于指定弹性管理器应重试重新分片的次数。该管理器是一个单例对象,应仅创建一次。快照
根据弹性管理器配置,该函数可能会将数据快照到主机内存中,以便在弹性事件期间供弹性处理程序使用。
减少分片
捕获到
jax.errors.JaxRuntimeError后,Pathways 会检查该错误是否是由丢失切片导致弹性事件引起的。如果需要,它将循环调用弹性处理程序,直到成功或达到最大重试次数。如果错误不是由弹性事件引起的,则会再次引发该错误。弹性处理程序的返回值会传递给调用方。增加分片
根据弹性管理器配置,如果存在不可用的切片,Pathways 会检查是否有其他切片变为可用。如果需要,它会立即保存快照(如果当前步骤之前尚未拍摄快照),并循环调用弹性处理程序,直到成功或达到最大重试次数。如果发生重新分片,弹性处理程序的返回值会传递给调用方。否则,返回
None。热插拔
热插拔是指 GKE JobSet API 的一项功能,通过该功能,优先级较高的作业可以快速接管优先级较低的作业的资源,从而最大限度地减少停机时间并确保更快地恢复。
创建 JobSet 后,GKE 会按照 JobSet 配置中的指定,在多个切片上调度工作负载。如果一个或多个切片上发生硬件故障,受影响的 Pod 会被标记为失败。重新调度此 Jobset 时,如果您选择在 GKE 集群中保留一个可用于低优先级作业的备用 slice,则 JobSet 系统会将失败的高优先级作业 slice 的工作负载重新映射到同一 GKE 集群中由低优先级作业使用的备用 slice。此重新映射过程通常用时不到 1 分钟。
在 JobSet 重启时,可能会在以下情况下发生热插拔:
- 默认模式:如果同一集群中有空闲的 TPU 切片,Kubernetes 调度器会优先将重启的作业调度到这些切片上,而不是等待故障切片修复。这样可以更快地恢复。
- 异构工作负载:在运行多个工作负载且配置了 Kubernetes PriorityClass 的集群中,重新启动的 JobSet 可能会触发热交换。如果重新启动的作业的亲和性与低优先级作业的资源相匹配,Kubernetes 会抢占低优先级作业,从而允许高优先级作业立即启动。例如,您可以使用
PriorityClass为 Pathways 工作器 pod 配置不同的优先级。
如需在集群中使用优先级,请定义一个优先级类,例如:
kind: PriorityClass metadata: name: high-prior-job value: 2000 globalDefault: false description: "This priority class should be used for high priority job."将此 YAML 应用到您的 GKE 集群:
kubectl apply -f high-prior-job.yaml接下来,通过将以下文本添加到
pathways-workerPod 的 podspec 中,将新的优先级类附加到 Pathways worker 作业。priorityClassName: high-prior-job后续步骤