将 JAX 工作负载移植到 Pathways

由于 JAX with Pathways 具有分布式特性,因此某些操作可能会因通信开销而无法很好地扩缩。虽然 Pathways 通过异步调度等功能最大限度地减少了这些 开销,但在将 JAX 工作负载移植到 Pathways 或 将 JAX with Pathways 工作负载扩缩到大量加速器时,您需要注意 一些事项。

准备工作

请确保您已备妥:

流程索引

JAX with Pathways 将 Pathways 集群中的所有设备都视为本地设备。 这简化了设备管理,并允许 JAX 利用所有可用资源。实际上,这意味着:

  • 对于所有设备,jax.process_index() 始终为 0。
  • jax.devices()jax.local_devices() 会返回整个作业中的所有 TPU 设备。

硬件类型和共置

为获得最佳性能,请将所有 Pathways 组件和用户作业放置在同一云端区域 Google Cloud 中。使用大型 CPU,例如 IFRT 代理和资源管理器。我们建议至少使用专用 n2-standard-64,它配备 64 个 vCPU 和 256 GB 内存。

PathwaysUtils

Pathways-utils 是一个 基于 Python 的 GitHub 代码库,提供基本实用程序和工具,可让 您简化在 Cloud 上的 Pathways 架构上部署和执行 JAX 工作负载的过程。此软件包会处理云环境所需的必要调整,让 JAX 开发者能够专注于其核心机器学习工作流,而无需进行最少的平台专用配置。具体而言,它提供:

  • “代理”JAX 后端:此自定义后端允许您的 JAX 应用通过设置 JAX_PLATFORMS=proxy 环境变量来使用 Pathways 基础架构。
  • 集成式性能分析实用程序:性能分析功能,可让您了解应用的性能。通过使用标准 JAX 性能分析 API(例如 jax.profiler.start_tracejax.profiler.start_server),您不仅可以分析 JAX 代码,还可以分析底层 Pathways 组件,从而全面了解云环境中的执行情况。
  • 使用 Orbax 进行分布式检查点:自定义 Orbax 检查点处理程序,可让您在使用 Pathways 环境中的 Orbax 库时使用分布式检查点并恢复检查点。此集成旨在在导入 pathwaysutils 的情况下正常运行,而无需对现有的 Orbax 检查点代码进行任何更改。
  • 弹性训练基元:提供基础弹性训练基元,您可以使用这些基元通过 Pathways 构建稳健且可伸缩的训练工作流。 这些基元允许您的训练作业动态适应可用资源的变化,从而提高云环境中的效率和弹性。

检查点

Orbax 已通过 Pathways 的全面测试,可用于通过 Cloud Storage 进行 分布式检查点和恢复。当您 调用 import pathwaysutils; pathwaysutils.initialize() 时,系统会注册自定义 ArrayHandler,该处理程序通过 IFRT 代理高效处理检查点操作 ,从而允许加速器上的 Pathways 工作器直接保存和恢复数据。train.py

共置 Python

共置 Python 是一种开源 JAX API,可让您直接在 TPU 或 GPU 主机上运行用户指定的 Python 代码, 这在多控制器 JAX 中更加简单。这样一来,计算密集型任务(例如数据加载和检查点)就可以避免在客户端和 TPU 机器之间传输数据。如需将 Pathways 集群配置为运行共置 Python JAX API,请按照 共置 Python README中的 说明进行操作。这些说明介绍了如何与 Pathways 工作器一起启动共置 Python Sidecar。

数据加载

在训练期间,我们会反复从数据集中加载批量数据,以馈送给模型。为了避免加速器缺少工作,请务必使用高效的异步数据加载器,将批次分片到多个主机。使用 Pathways 运行训练时,数据加载器会在 CPU 虚拟机上运行(与在多控制器设置中使用的 TPU 虚拟机不同),并将数据调度到 TPU 虚拟机。这会导致读取数据时的延迟较高,但通过在 CPU 主机上提前读取 X 个批次并将读取的数据异步调度到 TPU,可以部分缓解这种情况。在小规模到中等规模运行的情况下,此解决方案已足够。

如需以大规模获得最佳性能,我们强烈建议您通过使用 共置 Python 直接在加速器上运行数据流水线,从而共置输入 数据流水线。这样可以消除 CPU 瓶颈,并利用 TPU 的快速互连进行数据传输。

您可以在 RemoteIterator 实现中找到迁移基于 TFDS 的 输入流水线的参考实现,该实现位于 multihost_dataloading.py中。 此实现使用共置 Python JAX API 以分布式方式在多控制器 JAX 和 Pathways 上运行。

Jax 版本控制

Pathways 版本与 JAX 版本紧密相关,以确保兼容性和稳定性。为避免潜在问题,请验证您的 Pathways 工件和 JAX 版本是否一致。每个 Pathways 版本都通过 jax-<version> 形式的标记清楚地指定了 兼容的 JAX 版本。

编译缓存

Pathways 永久性编译缓存是一项功能,允许 Pathways 服务器将已编译的 XLA 可执行文件存储在永久性位置(例如 Cloud Storage)中,以避免冗余编译。此功能默认处于开启状态。缓存的位置作为 --gcs_scratch_location 标志传递给资源管理器和 Pathways 工作器容器。为尽可能降低相关存储费用,缓存会将生命周期政策附加到 Cloud Storage 位置。每个 Cloud Storage 存储桶最多只能有 50 个政策。因此,我们建议在所有工作负载中使用通用的 Cloud Storage 位置。

此缓存类似于 JAX 编译缓存 ,后者由 pathwaysutils.initialize() 针对 Pathways 工作负载停用。

编译缓存需要以下 Cloud Storage 权限:

  • storage.buckets.get:用于检索存储桶元数据。
  • storage.buckets.update:对于 Pathways 设置对象生命周期政策以强制执行缓存逐出 TTL 至关重要。
  • storage.objects.list:用于列出存储桶中的现有缓存对象。
  • storage.objects.create:用于将新的已编译可执行文件写入缓存。
  • storage.objects.get:用于从存储桶读取已缓存的可执行文件。

性能分析

您可以使用 JAX 性能分析器生成 JAX 程序的跟踪记录。Pathways 支持两种常见方法:

  • 程序化
    • 以编程方式从 JAX 代码捕获性能分析文件
  • 手动
    • 从 JAX 代码启动性能分析器服务器后,按需捕获性能分析文件

在这两种情况下,性能分析文件都会写入 Cloud Storage 存储桶。 Cloud Storage 存储桶中可能会创建多个跟踪记录文件,这些文件可能位于不同的时间戳文件夹下,例如:

  • 调用跟踪记录的主 Python 进程(通常是笔记本虚拟机): <jax-client-vm-name>.xplane.pb
  • Pathways IFRT 代理:client.<var>PATHWAYS_HEAD_NODE_NAME</var>.xplane.pb
  • Pathways 资源管理器:server.*<var>PATHWAYS_HEAD_NODE_NAME</var>.xplane.pb
  • Pathways 工作器:server.*<tpu-node-name>.xplane.pb

您可以使用 TensorBoard 通过运行以下命令来分析这些跟踪记录文件。如需详细了解 TensorBoard 及其所有性能分析工具, 请参阅 使用 Profiler 优化 TensorFlow 性能

# verify trace files are present
gsutil ls -l -r gs://BUCKET/PREFIX

# View on tensorboard
tensorboard --logdir=gs://BUCKET/PREFIX

替换以下内容:

  • BUCKET:用于存储跟踪记录文件的 Cloud Storage 存储桶
  • PREFIX:Cloud Storage 存储桶中用于存储跟踪记录文件的路径

程序化性能分析文件捕获

从代码内部捕获性能分析文件。性能分析文件保存在 gs://<bucket>/<prefix>下的时间戳目录中

import jax
import pathwaysutils

pathwaysutils.initialize()

jax.profiler.start_trace("gs://BUCKET/PREFIX")

# Run the operations to be profiled
key = jax.random.key(0)
x = jax.random.normal(key, (5000, 5000))
y = x @ x
y.block_until_ready()

jax.profiler.stop_trace()

手动性能分析文件捕获

如需手动捕获性能分析信息,您必须从 Python 代码启动性能分析器服务器:

import jax
import pathwaysutils

pathwaysutils.initialize()
jax.profiler.start_server(jax_profiler_port)

# Your JAX code
jax.profiler.stop_server() # this is functionally a no-op

在性能分析器服务器运行时,您可以捕获性能分析文件并将数据导出到目标 Cloud Storage 位置:

export DURATION_IN_SECS=6
export OUTPUT_DIR=gs://<var>BUCKET</var>/<var>PREFIX</var>

curl -d "{\"duration_ms\":\"${DURATION_IN_SECS} * 1000 }}\", \"repository_path\":\"${OUTPUT_DIR}\"}" -H "Content-Type: application/json" -X POST http://localhost:<jax_profiler_port

您可以在程序的跟踪记录中找到 IFRT 代理客户端方法(例如 CompileExecute)的时序信息。这些事件详细说明了在编译和执行期间与 IFRT 代理 gRPC 服务器的交互,并显示在名为 GrpcClientSessionUserFuturesWorkQueue 的线程上。通过检查跟踪记录中的此线程,您可以深入了解这些操作的性能。

XLA 标志

使用 Pathways 时,您需要在 pathways-proxy 容器中设置 XLA 标志。您可以使用 XPK 或 PathwaysJob API 来执行此操作。

使用 XPK 时,请设置如下所示的 XLA 标志:

--custom-pathways-proxy-server-args="--xla_flag_1=value1 --xla_flag_2=value2"

使用 PathwaysJob API 时,请设置如下所示的 XLA 标志:

apiVersion: pathways-job.pathways.domain/v1
kind: PathwaysJob
metadata:
  name: pathways-USER
spec:
  customComponents:
  - componentType: proxy_server
    customFlags:
    - --xla_flag_1=value1
    - --xla_flag_2=value2

替换以下内容:

  • USER:您的 Google Cloud 用户名
  • value[n]:您要设置的 XLA 标志

HLO 转储

如需深入研究提供给 XLA 编译器的高级优化器 (HLO) 输入,您可以将 Pathways 配置为将 HLO 转储到指定的 Cloud Storage 位置,如下所示:

apiVersion: pathways-job.pathways.domain/v1
kind: PathwaysJob
metadata:
  name: pathways-USER
spec:
  customComponents:
  - componentType: proxy_server
    customEnv:
    - name: XLA_FLAGS
      value: "--xla_dump_to=gs://your-bucket/your-desired-prefix/"

后续步骤