将 JAX 工作负载移植到 Pathways

由于 JAX 与 Pathways 的分布式特性,某些操作可能会因通信开销而无法很好地扩展。虽然 Pathways 通过异步调度等功能最大限度地减少了这些开销,但在将 JAX 工作负载移植到 Pathways 或将 JAX 与 Pathways 工作负载扩展到大量加速器时,您需要注意一些事项。

准备工作

请确保您已备妥:

流程索引

使用 Pathways 的 JAX 会将 Pathways 集群中的所有设备视为本地设备。这简化了设备管理,并允许 JAX 利用所有可用资源。在实践中,这意味着:

  • 对于所有设备,jax.process_index() 始终为 0。
  • jax.devices()jax.local_devices() 会返回整个作业中的所有 TPU 设备。

硬件类型和托管

为获得最佳性能,请将所有 Pathways 组件和用户作业放置在同一 Google Cloud 云地区中。使用大型 CPU,例如 IFRT 代理和资源管理器。我们建议至少使用专用 n2-standard-64,它配备 64 个 vCPU 和 256 GB 内存。

PathwaysUtils

Pathways-utils 是一个基于 Python 的 GitHub 代码库,提供必要的实用程序和工具,可让您在 Pathways on Cloud 架构上简化 JAX 工作负载的部署和执行。此软件包可处理云环境所需的必要调整,让 JAX 开发者能够专注于其核心机器学习工作流,而无需进行最少的平台特定配置。具体而言,它具有以下优势:

  • “代理”JAX 后端:此自定义后端通过设置 JAX_PLATFORMS=proxy 环境变量,使您的 JAX 应用能够使用 Pathways 基础架构。
  • 集成式分析实用程序:可让您了解应用性能的分析功能。通过使用 jax.profiler.start_tracejax.profiler.start_server 等标准 JAX 性能剖析 API,您不仅可以剖析 JAX 代码,还可以剖析底层 Pathways 组件,从而全面了解云环境中的执行情况。
  • 使用 Orbax 进行分布式检查点:一个自定义 Orbax 检查点处理程序,可让您在 Pathways 环境中使用 Orbax 库时使用分布式检查点并恢复检查点。此集成旨在无需对现有的 Orbax 检查点代码进行任何更改即可正常运行,前提是该代码导入了 pathwaysutils
  • 弹性训练基元:提供基础的弹性训练基元,您可以使用这些基元通过 Pathways 构建稳健且可伸缩的训练工作流。借助这些原语,训练作业可以动态适应可用资源的变化,从而提高云环境中的效率和弹性。

检查点

Orbax 经过 Pathways 的全面测试,可用于通过 Cloud Storage 进行分布式检查点设置和恢复。当您在 train.py 中调用 import pathwaysutils; pathwaysutils.initialize() 时,系统会注册一个自定义 ArrayHandler,该 ArrayHandler 通过 IFRT 代理高效处理检查点操作,从而使加速器上的 Pathways 工作器能够直接保存和恢复数据。

同位 Python

并置 Python 是一种开源 JAX API,可让您直接在 TPU 或 GPU 主机上运行用户指定的 Python 代码,这在多控制器 JAX 中更为简单。这样一来,就可以避免在客户端和 TPU 机器之间传输数据,从而执行更多计算密集型任务,例如数据加载和检查点设置。如需配置 Pathways 集群以运行同位 Python JAX API,请按照同位 Python README 中的说明进行操作。这些说明介绍了如何启动与 Pathways worker 并行的同位 Python 边车。

数据加载

在训练期间,我们会反复从数据集中加载批量数据,以馈送给模型。为了避免加速器缺少工作,请务必使用高效的异步数据加载器,将批次分片到多个主机。使用 Pathways 运行训练时,数据加载器会在 CPU 虚拟机上运行(与多控制器设置中使用的 TPU 虚拟机不同),并将数据调度到 TPU 虚拟机。这会导致读取数据时的延迟时间更长,但通过在 CPU 主机上预先读取 X 个批次并将读取的数据异步调度到 TPU,可以部分缓解此问题。如果以中小规模运行,此解决方案就足够了。

为了在扩缩时获得最佳性能,我们强烈建议您使用同位 Python 将输入数据流水线同位,以便直接在加速器上运行数据流水线。这样可以消除 CPU 瓶颈,并利用 TPU 的快速互连进行数据传输。

您可以在 multihost_dataloading.pyRemoteIterator 实现中找到迁移基于 TFDS 的输入流水线的参考实现。此实现通过并置的 Python JAX API 以分布式方式在多控制器 JAX 和 Pathways 上运行。

Jax 版本控制

Pathways 版本与 JAX 版本紧密相关,以确保兼容性和稳定性。为避免潜在问题,请验证您的 Pathways 制品和 JAX 版本是否一致。每个 Pathways 版本都会通过 jax-<version> 形式的标记明确指定兼容的 JAX 版本。

编译缓存

Pathways 持久性编译缓存是一项功能,可让 Pathways 服务器将编译后的 XLA 可执行文件存储在持久性位置(例如 Cloud Storage)中,以避免冗余编译。此功能默认处于启用状态。缓存的位置作为 --gcs_scratch_location 标志传递给资源管理器和 Pathways 工作器容器。为了尽可能降低相关存储费用,缓存会将生命周期政策附加到 Cloud Storage 位置。每个 Cloud Storage 存储桶最多只能有 50 项政策。因此,我们建议在所有工作负载中使用通用的 Cloud Storage 位置。

此缓存类似于 JAX 编译缓存,后者已通过 pathwaysutils.initialize() 针对 Pathways 工作负载停用。

分析

您可以使用 JAX 分析器生成 JAX 程序的轨迹。Pathways 支持两种常见方式:

  • 程序化
    • 以程序化方式从 JAX 代码中捕获性能分析文件
  • 手动
    • 从 JAX 代码启动性能分析器服务器后,按需捕获性能分析文件

在这两种情况下,配置文件都会写入 Cloud Storage 存储桶。系统会在 Cloud Storage 存储桶中创建多个轨迹文件,这些文件可能位于不同的时间戳文件夹下,例如:

  • 调用轨迹的主 Python 进程(通常是笔记本虚拟机): <jax-client-vm-name>.xplane.pb
  • Pathways IFRT 代理:client.<var>PATHWAYS_HEAD_NODE_NAME</var>.xplane.pb
  • Pathway 资源管理器:server.*<var>PATHWAYS_HEAD_NODE_NAME</var>.xplane.pb
  • Pathway worker:server.*<tpu-node-name>.xplane.pb

您可以通过运行以下命令,使用 TensorBoard 分析这些轨迹文件。如需详细了解 TensorBoard 及其所有分析工具,请参阅使用 Profiler 优化 TensorFlow 性能

# verify trace files are present
gsutil ls -l -r gs://BUCKET/PREFIX

# View on tensorboard
tensorboard --logdir=gs://BUCKET/PREFIX

替换以下内容:

  • BUCKET:用于存储轨迹文件的 Cloud Storage 存储桶
  • PREFIX:Cloud Storage 存储桶中用于存储轨迹文件的路径

程序化配置文件捕获

从代码内部捕获配置文件。配置文件保存在 gs://<bucket>/<prefix> 内的时间戳目录中

import jax
import pathwaysutils

pathwaysutils.initialize()

jax.profiler.start_trace("gs://BUCKET/PREFIX")

# Run the operations to be profiled
key = jax.random.key(0)
x = jax.random.normal(key, (5000, 5000))
y = x @ x
y.block_until_ready()

jax.profiler.stop_trace()

手动捕获性能剖析文件

如需手动捕获性能分析信息,您必须从 Python 代码中启动性能分析器服务器:

import jax
import pathwaysutils

pathwaysutils.initialize()
jax.profiler.start_server(jax_profiler_port)

# Your JAX code
jax.profiler.stop_server() # this is functinoally a no-op

在分析器服务器运行期间,您可以捕获配置文件并将数据导出到目标 Cloud Storage 位置:

export DURATION_IN_SECS=6
export OUTPUT_DIR=gs://<var>BUCKET</var>/<var>PREFIX</var>

curl -d "{\"duration_ms\":\"${DURATION_IN_SECS} * 1000 }}\", \"repository_path\":\"${OUTPUT_DIR}\"}" -H "Content-Type: application/json" -X POST http://localhost:<jax_profiler_port

您可以在程序的轨迹中找到 IFRT 代理客户端方法(例如 CompileExecute)的时序信息。这些事件详细说明了在编译和执行期间与 IFRT 代理 gRPC 服务器的互动,它们会显示在名为 GrpcClientSessionUserFuturesWorkQueue 的线程上。通过检查轨迹中的此线程,您可以深入了解这些操作的性能。

XLA 标志

使用 Pathways 时,您需要在 pathways-proxy 容器中设置 XLA 标志。您可以使用 XPK 或 PathwaysJob API 来完成此操作。

使用 XPK 时,请设置如下所示的 XLA 标志:

--custom-pathways-proxy-server-args="--xla_flag_1=value1 --xla_flag_2=value2"

使用 PathwaysJob API 时,请设置如下所示的 XLA 标志:

apiVersion: pathways-job.pathways.domain/v1
kind: PathwaysJob
metadata:
  name: pathways-USER
spec:
  customComponents:
  - componentType: proxy_server
    customFlags:
    - --xla_flag_1=value1
    - --xla_flag_2=value2

替换以下内容:

  • USER:您的 Google Cloud 用户名
  • value[n]:您要设置的 XLA 标志

HLO 转储

为了深入了解提供给 XLA 编译器的高级优化器 (HLO) 输入,您可以配置 Pathways,以将 HLO 转储到指定的 Cloud Storage 位置,如下所示:

apiVersion: pathways-job.pathways.domain/v1
kind: PathwaysJob
metadata:
  name: pathways-USER
spec:
  customComponents:
  - componentType: proxy_server
    customEnv:
    - name: XLA_FLAGS
      value: "--xla_dump_to=gs://your-gcs-bucket/your-desired-prefix/"

后续步骤