使用 JAX 在 Cloud TPU 上构建生产 AI

JAX AI 技术栈通过一系列由 Google 支持的可组合库扩展了 JAX 数值核心,使其发展成为一个强大的端到端开源平台,可用于大规模机器学习。因此,JAX AI 技术栈包含一个全面而强大的生态系统,可满足整个机器学习生命周期的需求:

  • 工业规模的基础:JAX AI 技术栈专为大规模应用而设计,利用 ML Pathways 在数万个芯片上编排训练,并利用 Orbax 实现弹性、高吞吐量的异步检查点,从而实现最先进模型的生产级训练。

  • 完整且可用于生产用途的工具包:JAX AI 技术栈为整个开发流程提供了一套全面的库:Flax 用于灵活的模型编写;Optax 用于可组合的优化策略;Grain 用于确定性数据流水线,这对于可重现的大规模运行至关重要。

  • 最高专业性能:为了实现最高的硬件利用率,JAX AI 技术栈提供了专用库,包括用于最先进的自定义内核的 Tokamax、用于提升训练和推理速度的非侵入性量化的 Qwix,以及用于深入的硬件集成性能分析的 XProf。

  • 完整的生产路径:JAX AI 技术栈可实现从研究到部署的无缝过渡。这包括作为基础模型训练的规模化参考框架的 MaxText、用于实现先进的强化学习 (RL) 和校准的 Tunix,以及统一的推理解决方案,该解决方案具有 vLLM TPU 集成和 JAX 服务运行时。

JAX AI 技术栈的理念是使用松散耦合的组件之一,每个组件都擅长一项任务。JAX 本身并非一个单体式机器学习框架,而是范围较窄,专注于高效的数组运算和程序转换。该生态系统基于此核心框架构建,可提供与机器学习模型训练和其他类型工作负载(例如科学计算)相关的各种功能。

这种松散耦合的组件系统可让您以最适合您需求的方式选择和组合库。从软件工程的角度来看,此架构还允许您以迭代方式更新传统上被视为核心框架组件(例如,数据流水线和检查点)的功能,而不会有使核心框架不稳定或进入发布周期的风险。鉴于大多数功能都是在库中实现的,而不是通过更改单体式框架来实现,这使得核心数值库更持久耐用,并且能够适应未来技术格局的变化。

以下部分将从技术角度概述 JAX AI 技术栈、其主要功能、背后的设计决策,以及它们如何结合在一起,为现代机器学习工作负载构建持久的平台。

JAX AI 技术栈和其他生态系统组件

组件 函数/说明
JAX AI 技术栈核心和组件1
JAX 面向加速器的数组计算和程序转换(JIT、grad、vmap、pmap)。
Flax 灵活的神经网络编写库,可直观地创建和修改模型。
Optax 一个包含可组合的梯度处理和优化转换的库。
Orbax “任意规模”分布式检查点库,以实现大规模训练弹性。
粒度 一个规模化、确定性高且可设置检查点的输入数据流水线库。
JAX AI 技术栈 - 基础设施
XLA 适用于 TPU、CPU 和 GPU 的开源机器学习编译器。
Pathways 用于在数万个芯片之间编排计算的分布式运行时。
JAX AI 技术栈 - 高级开发
Pallas 一种 JAX 扩展程序,用于编写以 Python 实现的低级别、高性能自定义内核。
Tokamax 一个精选的库,其中包含最先进的高性能自定义内核(例如,注意力机制)。
Qwix 一个用于量化(PTQ、QAT、QLoRA)的全面且非侵入式库。
JAX AI 技术栈 - 应用
MaxText/MaxDiffusion 用于训练基础模型(例如 LLM 和 Diffusion)的旗舰级规模化参考框架。
Tunix 一个用于实现先进的训练后和校准(RLHF、DPO)的框架。
vLLM 一种高性能 LLM 推理解决方案,使用 vLLM 框架的内置集成。
XProf 一种深度集成硬件的性能分析器,用于进行系统范围的性能分析。

1包含在 jax-ai-stack Python 软件包中。

图 1:JAX AI 技术栈和生态系统组件

JAX AI 技术栈

架构的必要性:超越框架的性能

随着模型架构趋于融合(例如,多模态混合专家 [MoE] Transformer),对峰值性能的追求推动了 Megakernel 兴起。从效果上来说,Megakernel 是一个特定模型(或其中很大一部分)的整个前向传递,使用较低级别的 API(例如 NVIDIA GPU 上的 CUDA SDK)手动编码而成。这种方法通过主动重叠计算、内存和通信来实现最大的硬件利用率。 研究界最近的工作表明,这种方法可以显著提高 GPU 推理的吞吐量,在某些情况下甚至超过 22%。这种趋势不仅限于推理;有证据表明,一些大规模训练工作涉及低级别硬件控制,以实现显著的效率提升。

如果这一趋势加速发展,那么目前所有高级别框架都有可能变得不太相关,因为对于成熟稳定的架构而言,对硬件的低级别访问最终才是决定性能的关键。这给所有现代机器学习技术栈带来了挑战:如何在不牺牲高级框架的效率和灵活性的前提下,提供专家级硬件控制。

为了让 TPU 能够清晰地达到这一性能水平,生态系统必须公开更接近硬件的 API 层,从而实现这些高度专业化内核的开发。JAX 技术栈旨在通过提供从 XLA 编译器的高级自动化优化到 Pallas 内核编写库的精细手动控制的抽象连续体(见图 2)来解决此问题。

图 2:JAX 抽象连续统一体

JAX 抽象连续统一体

核心 JAX AI 技术栈

核心 JAX AI 技术栈包含五个关键库,可为模型开发提供基础:

JAX:可组合的高性能程序转换的基础

JAX 是一个面向加速器的数组计算和程序转换 Python 库,专为高性能数值计算和大规模机器学习而设计。借助 JAX 函数式编程模型和类似 NumPy 的 API,JAX 为更高级别的库奠定了坚实的基础。

凭借 JAX“编译器优先”的设计,通过利用 XLA(请参阅 XLA 部分)进行积极的整体程序分析、优化和硬件定位,从而在本质上提升了可伸缩性。 JAX 强调函数式编程(例如纯函数),这使得其核心程序转换更易于处理,并且至关重要的是,可进行组合。

您可以混合搭配使用这些核心转换,以实现高性能并根据模型大小、集群大小和硬件类型来扩缩工作负载:

  • jit:将 Python 函数即时编译为优化的融合 XLA 可执行文件。
  • grad:自动微分,支持前向模式和反向模式,以及高阶导数。
  • vmap:自动向量化,无需修改函数逻辑即可实现无缝批处理和数据并行处理。
  • pmap/shard_map:在多个设备(例如 TPU 核心)上自动并行处理,从而构成分布式训练的基础。

与 XLA 的 GSPMD(通用 SPMD)模型的无缝集成使 JAX 能够自动跨大型 TPU Pod 并行化计算,而只需进行极少的代码更改。在大多数情况下,扩缩只需使用高级别分片注解。

Flax:灵活的神经网络编写

Flax 通过提供直观的面向对象的模型构建方法,简化了 JAX 中神经网络的创建、调试和分析。虽然 JAX 的功能性 API 非常强大,但它为习惯使用 PyTorch 等框架的开发者提供了一种更熟悉的基于层的抽象,而不会带来任何性能损失。

这种设计简化了对训练后的模型组件的修改或组合。 LoRA 和量化等技术需要可操控的模型定义,而 Flax 的 NNX API 通过 Python 式接口提供此类定义。NNX 会封装模型状态,可减少用户认知负载,并允许以程序化方式遍历和修改模型层次结构。

主要优势:

  • 直观的面向对象的 API:简化了模型构建,并支持子模块替换和部分初始化等高级应用场景。
  • 与核心 JAX 保持一致:Flax 提供与 JAX 的函数式范式完全兼容的提升转换,可提供 JAX 的全部性能,同时提升开发者友好度。

Optax:可组合的梯度处理和优化策略

Optax 是一个适用于 JAX 的梯度处理和优化库。它旨在为模型构建者提供可按自定义方式重新组合的基础组件,以便训练深度学习模型以及其他应用。Optax 基于核心 JAX 库的功能构建而成,提供经过充分测试的高性能损失函数和优化器函数库以及可用于训练机器学习模型的相关技术。

动机

损失的计算和最小化是实现机器学习模型训练的核心。JAX 核心库支持自动微分,可提供训练模型所需的数值功能,但它不提供热门优化器(例如 RMSPropAdam)或损失函数(例如 CrossEntropyMSE)的标准实现。虽然您可以实现这些函数(一些高级开发者会选择这样做),但优化器实现中的 bug 会导致难以诊断的模型质量问题。Optax 不会让用户实现此类关键部分,而是提供经过正确性和性能测试的这些算法的实现。

优化理论领域完全属于研究范畴,但它在训练中的核心作用也使其成为训练生产机器学习模型不可或缺的一部分。充当此角色的库既要足够灵活,能够适应快速的研究迭代,又要足够强大且性能出色,能够可靠地用于生产模型训练。它还应提供经过充分测试的先进算法实现,这些实现应与标准方程式相符。Optax 库通过其模块化可组合架构并强调正确可读代码,旨在实现这一目标。

设计

Optax 旨在通过提供可读、经过充分测试且高效的核心算法实现,加快研究速度并促进从研究到生产的过渡。Optax 的用途不仅限于深度学习,不过在此上下文中,它可以被视为一系列以纯函数式方式实现的知名损失函数、优化算法和梯度转换,符合 JAX 的理念。借助一系列众所周知的损失函数优化器,用户可以轻松可靠地开始使用。

Optax 采用模块化方法,可让您将多个优化器链接在一起,然后进行其他常见的转换(例如梯度裁剪),并使用 MultiStep 或 Lookahead 等常见技术封装它们,只需几行代码即可实现强大的优化策略。借助灵活的界面,您可以研究新的优化算法,并使用强大的二阶优化技术,例如 Shampoo 或 Muon。

# Optax implementation of a RMSProp optimizer with a custom learning rate
#  schedule, gradient clipping and gradient accumulation.
optimizer = optax.chain(
  optax.clip_by_global_norm(GRADIENT_CLIP_VALUE),
  optax.rmsprop(learning_rate=optax.cosine_decay_schedule(init_value=lr,decay_steps=decay)),
  optax.apply_every(k=ACCUMULATION_STEPS)
)

# The same thing, in PyTorch
optimizer = optim.RMSprop(model_params, lr=LEARNING_RATE)
scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=TOTAL_STEPS)
for i, (inputs, targets) in enumerate(data_loader):
    # ... Training loop body ...
    if (i + 1) % ACCUMULATION_STEPS == 0:
        torch.nn.utils.clip_grad_norm_(model.parameters(), GRADIENT_CLIP_VALUE)
        optimizer.step()
        scheduler.step()
 optimizer.zero_grad()

上一个代码段展示了如何设置具有自定义学习速率、梯度裁剪和梯度累积的优化器。

主要优势

  • 强大的库:提供全面的损失、优化器和算法库,重点在于正确性和可读性。
  • 模块化可链接的转换:借助此灵活的 API,您可以以声明方式制定强大的复杂优化策略,而无需修改训练循环。
  • 功能强大且可伸缩:纯函数式实现可与 JAX 的并行处理机制(例如 pmap)无缝集成,让您能够使用相同的代码从单个主机扩容到大型集群。

Orbax/TensorStore - 大规模分布式检查点

Orbax 是一个适用于 JAX 的检查点库,专为从单设备到大规模分布式训练的任何规模而设计。其目标是统一碎片化的检查点实现,并向更广泛的受众群体提供关键的性能功能,例如异步和多层级检查点。Orbax 可实现大规模训练作业所需的弹性,并提供灵活的格式来发布检查点。

与对整个系统状态进行快照的通用检查点和恢复系统不同,使用 Orbax 的机器学习检查点仅选择性地持久保留恢复训练模型权重、优化器状态和数据加载器状态所需的必要信息。这种有针对性的方法可最大限度地减少加速器停机时间。Orbax 通过将 I/O 操作与计算重叠来实现这一点,这对于大型工作负载来说是一项关键功能。时间加速器的空闲时间缩短为设备托管数据传输的时长,这可以进一步与下一个训练步骤重叠,从而使检查点从性能角度来看几乎是免费的。

从根本上讲,Orbax 使用 TensorStore 来高效地并行读取和写入数组数据。Orbax API 可抽象化这种复杂性,提供一个方便用户使用的界面来处理 PyTrees,这是 JAX 中模型的标准表示形式。

主要优势:

  • 广泛采用:Orbax 每月有数百万次下载,是分享机器学习制品的常用媒介。
  • 简化复杂性:Orbax 会消除分布式检查点的复杂性,包括异步保存、原子性和文件系统详细信息。
  • 灵活:Orbax 不仅提供适用于常见应用场景的 API,还允许您自定义工作流以满足特殊要求。
  • 高性能且可伸缩:异步检查点、高效的存储格式 (OCDBT) 和智能数据加载策略等功能可确保 Orbax 能够扩容到涉及数万个节点的训练运行。

Grain:确定性高且可伸缩的输入数据流水线

Grain 是一个 Python 库,用于读取和处理数据,以训练和评估 JAX 模型。Grain 灵活、快速且确定性高,并支持检查点等高级功能,这些功能对于成功训练大型工作负载至关重要。它支持常用的数据格式和存储后端,还提供灵活的 API,可将支持范围扩展到用户特定的格式和原生不受支持的后端。虽然 Grain 主要是为了与 JAX 搭配使用而设计,但它与框架无关,不需要 JAX 即可运行,并且还可以与其他框架搭配使用。

动机

数据流水线是训练基础设施的关键组成部分,需要足够灵活,以便高效表达常见转换,并且性能足够出色,能够始终让加速器保持繁忙状态。 它们还需要能够适应多种存储格式和后端。 由于步长时间较长,大规模训练大型模型对数据流水线提出了常规训练工作负载之外的额外要求,主要围绕确定性和可重现性2。Grain 库采用灵活的架构设计,可满足这些需求。


2PaLM 论文的第 5.1 节中,作者指出,尽管启用了梯度裁剪,但他们仍观察到损失出现了非常大的峰值。解决方案是移除有问题的数据批次,并从损失峰值之前的检查点重新开始训练。只有在完全确定且可重现的训练设置下,才能实现这一点。

设计

从最高层面来看,有两种方式来构建输入流水线:将数据工作器作为单独的集群,或者将数据工作器与驱动加速器的主机放在同一位置。Grain 出于多种原因选择了后者。

加速器与在训练步骤中通常处于闲置状态的强大主机相结合,因此是运行输入数据流水线的自然选择。此实现还有其他优势,它通过在输入和计算之间提供一致的分片视图,简化了数据分片视图。有人可能会认为,将数据工作器放在加速器主机上可能会导致主机 CPU 饱和,但这并不妨碍使用 RPC 将计算密集型转换分流到另一个集群3

在 API 方面,Grain 采用纯 Python 实现,支持多个进程和灵活的 API,让您能够根据易于理解的转换范式将流水线阶段组合在一起,从而实现任意复杂的数据转换。

Grain 开箱即用,支持 ArrayRecordBagz 等高效的随机访问数据格式,以及 ParquetTFDS 等其他常用数据格式。Grain 默认支持从本地文件系统读取数据,以及从 Cloud Storage 读取数据。除了支持常用的存储格式和后端之外,存储层的简洁抽象还可让您添加对现有数据源的支持或封装现有数据源,以使其与 Grain 库兼容。


3这就是多模态数据流水线的运作方式,例如,图片和音频词元化器本身就是模型,它们在自己的集群中、自己的加速器上运行,输入流水线会进行 RPC 调用,以将数据样本转换为词元流。

主要优势

  • 确定性数据馈送:将数据工作器与加速器放在同一位置,并将其与稳定的全局 shuffle 和可设置检查点的迭代器耦合,可使用 Orbax 在一致的快照中将模型状态和数据流水线状态一起设置检查点,从而提高训练过程的确定性。
  • 灵活的 API,可实现强大的数据转换:借助灵活的纯 Python 转换 API,您可以在输入处理流水线中执行广泛的数据转换。
  • 可扩展地支持多种格式和后端:可扩展的数据源 API 支持常用的存储格式和后端,并可让您添加对新格式和后端的支持。
  • 强大的调试界面:借助数据流水线可视化工具和调试模式,您可以检查、调试和优化数据流水线的性能。

扩展的 JAX AI 技术栈

除了核心栈之外,丰富的专用库生态系统还提供了端到端机器学习开发所需的基础设施、高级工具和应用层解决方案。

基本的基础设施:编译器和运行时

XLA:以编译器为中心的硬件无关引擎

动机

XLA(加速线性代数)是 Google 的领域专用编译器,可与 JAX 完美集成,并支持 TPU、CPU 和 GPU 硬件设备。XLA 旨在成为一个与硬件无关的代码生成器,面向 TPU、GPU 和 CPU。

XLA 编译器采用“编译器优先”的设计,这是一项基本架构选择,可在快速发展的研究领域中打造持久的优势。相比之下,其他生态系统中以内核为中心的普遍方法依赖于手动优化的库来实现性能。虽然这种方法对于稳定且成熟的模型架构非常有效,但会成为创新的瓶颈。当新研究引入新颖的架构时,生态系统必须等待编写和优化新内核。不过,我们以编译器为中心的设计通常可以推广到新模式,从而从第一天起就为前沿研究提供高性能途径。

设计

XLA 的工作方式是,对 JAX 在跟踪过程中(例如,当函数使用 @jax.jit 进行修饰时)生成的计算图进行即时 (JIT) 编译。

此编译过程遵循多阶段流水线:

  1. JAX 计算图
  2. 高级优化器 (HLO)
  3. 低级优化器 (LLO)
  4. 硬件代码
  • 从 JAX 图到 HLO:JAX 计算图会转换为 XLA 的 HLO 表示形式。从高层面来看,系统会应用强大的、与硬件无关的优化,例如运算符融合和高效的内存管理。StableHLO 方言可作为此阶段的持久性版本控制接口。
  • 从 HLO 到 LLO:在进行高级别优化后,硬件专用后端会接管,将 HLO 表示形式降为面向机器的 LLO。
  • 从 LLO 到硬件代码:LLO 最终会被编译为高效的机器代码。对于 TPU,此代码捆绑为直接发送到硬件的超长指令字 (VLIW) 数据包。

在扩缩方面,XLA 的设计围绕并行处理构建。它采用算法来最大限度地利用芯片上的矩阵乘法单元 (MXU)。在芯片之间,XLA 使用 SPMD(单程序多数据),这是一种基于编译器的并行处理技术,可在所有设备上使用单个程序。这种强大的模型通过 JAX API 公开,让您能够通过高级分片注解来管理数据、模型或流水线并行性。

对于更复杂的并行模式,还可以使用多程序多数据 (MPMD),并且像 PartIR:MPMD 这样的库也允许 JAX 用户提供 MPMD 注解。

主要优势
  • 编译:计算图的即时编译可优化内存布局、缓冲区分配和内存管理。 基于内核的方法等替代方案将此负担转嫁给了开发者。在大多数情况下,XLA 可以在不影响开发者速度的情况下实现出色的性能。
  • 并行性:XLA 通过 SPMD 实现多种形式的并行性,并在 JAX 级别公开。这样一来,您就可以表达分片策略,从而在数千个芯片上实现模型实验和可伸缩性。

Pathways:一种用于大规模分布式计算的统一运行时

Pathways 为分布式训练和推理提供抽象,并内置容错和恢复功能,使机器学习研究人员能够像使用一台强大的机器一样进行编码。

动机

为了能够训练和部署大型模型,需要数百到数千个芯片。这些芯片分布在多个机架和宿主机上。训练作业是一个大规模的同步程序,需要所有这些芯片及其各自的主机协同工作,对已并行化(分片)的 XLA 计算进行处理。对于可能需要数万个以上芯片的大语言模型,此服务必须能够跨数据中心网络中的多个 Pod,同时还能够在一个 Pod 内使用芯片间互连 (ICI) 和芯片上互连 (OCI) 网络。

设计

ML Pathways 是我们用于协调主机和 TPU 芯片之间的分布式计算的系统。它旨在实现数十万个加速器之间的可伸缩性和效率。对于大规模训练,它为多个 Pod 作业、超大规模 XLA 集成、编译服务和远程 Python 提供了一个 Python 客户端。它还支持跨切片并行处理和抢占容忍,从而能够从资源抢占自动恢复。

Pathways 纳入了优化的跨主机集合,使 XLA 计算图能够扩展到单个 TPU Pod 之外。它扩展了 XLA 对数据、模型和流水线并行处理的支持,通过集成分布式运行时来管理 DCN 通信与 XLA 通信原语,从而跨 TPU 切片边界使用数据中心网络 (DCN)。

主要优势

与 JAX 集成的单控制器架构是一项关键的抽象化技术。借助该架构,研究人员可以探索各种分片和并行处理策略,以便在训练和部署过程中轻松扩容到数万个芯片。

高级开发:性能、数据和效率

Pallas:采用 JAX 编写高性能自定义内核

虽然 JAX 是编译器优先,但在某些情况下,您可能需要对硬件进行精细控制,以实现最佳性能。Pallas 是 JAX 的扩展,可用于为 GPU 和 TPU 编写自定义内核。它旨在精确控制生成的代码,同时兼具 JAX 跟踪和 jax.numpy API 的高级人体工程学特性。

Pallas 采用基于网格的并行处理模型,其中用户定义的内核函数会在多维并行工作组网格中启动。它允许您通过定义张量在较慢但较大的内存(例如 HBM)和较快但较小的芯片上内存(例如 TPU 上的 VMEM、GPU 上的共享内存)之间的平铺和传输方式,明确管理内存层次结构,并使用索引映射将网格位置与特定数据块相关联。 Pallas 可以通过将内核编译为适合目标架构的中间表示形式(适用于 TPU 的 Mosaic)或利用适用于 GPU 的 Triton 等技术,降低相同的内核定义,以便在 Google 的 TPU 和各种 GPU 上高效执行。借助 Pallas,您可以编写专门针对注意力机制等块的高性能内核,从而在目标硬件上实现最佳模型性能,而无需依赖于供应商专用的工具包。

Tokamax:精选的先进内核库

如果 Pallas 是用于编写内核的工具,则 Tokamax 就是一个库,其中包含支持 TPU 和 GPU 的先进自定义加速器内核。Tokamax 基于 JAX 和 Pallas 构建,可让您充分利用硬件的强大功能。它还提供用于构建和自动调优自定义内核的工具。

动机

JAX 源自 XLA,是一个以编译器优先的框架,但在少数情况下,您可能需要直接控制硬件才能实现最佳性能4。自定义内核对于从昂贵的机器学习加速器资源(例如 TPU 和 GPU)获得最佳性能至关重要。虽然它们被广泛用于实现关键运算符(例如注意力机制)的高效执行,但实现它们需要深入了解模型和目标硬件架构。Tokamax 提供一个权威来源,其中包含经过精心挑选、充分测试的高性能内核,并提供强大的共享基础设施来支持这些内核的开发、维护和生命周期管理。此类库还可以作为参考实现,供您根据需要进行构建和自定义。这样一来,您就可以专注于建模工作,而无需担心基础设施问题。


4这是一种成熟的范式,在 CPU 领域已有先例,其中编译后的代码构成了程序的大部分,而开发者会降至内建函数或内嵌汇编来优化对性能至关重要的部分。

设计

对于任何给定的内核,Tokamax 都提供了一个可由多种实现支持的通用 API。例如,TPU 内核可以通过标准 XLA 低级化来实现,也可以通过 Pallas/Mosaic-TPU 明确实现。GPU 内核可以通过标准 XLA 低级化、Mosaic-GPU 或 Triton 实现。默认情况下,Tokamax API 会根据定期自动调优和基准测试运行的缓存结果,为给定的配置选择最知名的实现,不过您也可以根据需要选择特定的实现。随着时间的推移,可能会添加新的实现,以更好地利用新硬件世代中的特定功能,从而进一步提升性能。

除了内核本身之外,Tokamax 库的一个关键组件是支持您编写自定义内核的基础设施。例如,借助自动调优基础设施,您可以定义一组可配置的参数(例如,板块大小),Tokamax 可以对这些参数进行详尽的扫描,以确定并缓存最佳的调优设置。夜间回归可保护您免受因底层编译器基础设施或其他依赖项发生变化而导致的意外性能和数值问题的影响。

主要优势
  • 无缝的开发者体验:统一的精选库提供关键内核的已知良好的高性能实现,并以程序化方式和在文档中清晰表达支持的硬件世代和预期性能。这样可以最大限度地减少碎片化和流失。
  • 灵活性和生命周期管理:您可以选择不同的实现方式,甚至可以在适当的时候更改它们。 例如,如果 XLA 编译器增强了对某些操作的支持,不再需要自定义内核,则可以进行弃用和迁移。
  • 可扩展性:您可以实现自己的内核,同时利用充分受支持的共享基础设施,从而专注于增值功能和优化。清晰编写的标准实现可作为用户学习和扩展的起点。

Qwix:非侵入式、全面的量化

Qwix 是一个适用于 JAX AI 技术栈的全面量化库,支持所有阶段的 LLM 和其他模型类型,包括训练(量化感知训练 [QAT]、量化技术 [QT]、量化低秩自适应 [QLoRA] 和推理训练后量化 [PTQ]),同时面向 XLA 和设备端运行时。

动机

现有的量化库(尤其是在 PyTorch 生态系统中)通常用途有限(例如,仅限 PTQ 或仅限 QLoRA)。这种碎片化的格局迫使您切换工具,从而阻碍了代码的持续使用以及训练和推理之间精确的数值匹配。此外,许多解决方案需要对模型进行大幅修改,从而将模型逻辑与量化逻辑紧密耦合。

设计

Qwix 的设计理念强调提供全面的解决方案,以及至关重要的非侵入式模型集成。它采用分层、可扩展的设计,基于可重复使用的功能性 API 构建。

这种非侵入式集成是通过精心设计的拦截机制实现的,该机制可将 JAX 函数重定向到其量化对等函数。这样一来,您无需进行任何修改即可集成模型,从而将量化代码与模型定义完全分离。

以下示例演示了如何将 w4a4(4 位权重、4 位激活)量化应用于 LLM 的 MLP 层,并将 w8(8 位权重)量化应用于嵌入器。如需更改量化 recipe,您只需更新规则列表。

fp_model = ModelWithoutQuantization(...)
rules = [
    qwix.QuantizationRule(
        module_path=r'embedder',
        weight_qtype='int8',
    ),
    qwix.QuantizationRule(
        module_path=r'layers_\d+/mlp',
        weight_qtype='int4',
        act_qtype='int4',
        tile_size=128,
        weight_calibration_method='rms,7',
    ),
]
quantized_model = qwix.quantize_model(fp_model, qwix.PtqProvider(rules))
主要优势
  • 全面的解决方案:Qwix 广泛适用于各种量化场景,可确保训练和推理之间一致的代码使用。
  • 非侵入式模型集成:如示例所示,您只需一行代码即可集成模型。这样一来,您就可以对多种量化方案使用超参数,找到最佳的质量与性能权衡。
  • 与其他库联合:Qwix 可与 JAX AI 技术栈无缝集成。例如,当使用 Qwix 对模型进行量化时,Tokamax 会自动调整为使用量化版本的内核,而无需额外的用户代码。
  • 研究友好:Qwix 的基础 API 和可扩展架构可帮助研究人员探索新算法,并利用集成式基准和评估工具轻松进行比较。

应用层:训练和校准

基础模型训练:MaxText 和 MaxDiffusion

MaxTextMaxDiffusion 分别是 Google 的旗舰级 LLM 和 Diffusion 模型训练框架。这些仓库包含一系列常用开放权重模型的高度优化实现。它们具有双重用途:既可用作现成的模型训练代码库,也可作为基础模型构建者可用来构建的基础模型参考。

动机

整个行业对训练生成式 AI 模型的兴趣正在迅速增长。 开放模型的普及加速了这一趋势,提供了经过验证的架构。训练和调整这些模型需要高性能、高效率、可扩容到大量芯片,以及清晰易懂的代码。MaxText 和 MaxDiffusion 是全面的解决方案,可在 TPU 或 GPU 上使用,旨在满足这些需求。

设计

MaxText 和 MaxDiffusion 是基础模型代码库,在设计时考虑了可读性和性能。它们采用经过充分测试的可重用组件构建而成:使用自定义内核(如 Tokamax)以实现最佳性能的模型定义、用于编排和监控的训练框架,以及一个强大的配置系统,让您可以通过直观的界面控制分片和量化(使用 Qwix)等细节。我们还纳入了多层级检查点等高级可靠性功能,以确保持续的 goodput。

MaxText 和 MaxDiffusion 使用一流的 JAX 库(Qwix、Tunix、Orbax 和 Optax)来提供核心功能。这些库提供稳健且可伸缩的基础设施,可减少开发开销,让您专注于建模任务。对于推理,模型代码是共享的,以实现高效且可伸缩的服务。

主要优势
  • 设计上注重性能:MaxText 和 MaxDiffusion 具有可实现高“goodput”(有用吞吐量)的训练基础设施,以及针对高 MFU(模型 Flops 利用率)进行了优化的模型实现,因此可开箱即用地大规模提供高性能。
  • 专为大规模应用而打造:这些框架利用 JAX AI 技术栈(尤其是 Pathways)的强大功能,让您能够从数十个芯片无缝扩容到数万个芯片。
  • 为基础模型构建者提供坚实的基础:高质量、可读的实现为开发者提供了一个坚实的起点,他们可以将其用作端到端解决方案,也可以将其用作自定义实现的参考实现。

训练后和校准:Tunix 框架

Tunix 提供先进的开源强化学习 (RL) 算法,以及强大的框架和基础设施,为开发者提供简化的途径,以便他们使用 JAX 和 TPU 试验 LLM 训练后技术,包括监督式微调 (SFT) 和校准。

动机

训练后是充分发挥大语言模型真正强大之处的关键步骤。强化学习 (RL) 阶段对于培养校准和推理能力尤为关键。这一领域的开源开发几乎完全基于 PyTorch 和 GPU,这使得 JAX 和 TPU 解决方案存在根本性差距。Tunix (Tune-in-JAX) 是一个高性能的 JAX 原生库,旨在填补这一空白。

设计

Tunix 图

从框架的角度来看,Tunix 支持一种先进的设置,可将 RL 算法与基础设施明确分离。它提供了一个轻量级、类似客户端的 API,可隐藏 RL 基础设施的复杂性,让您开发新算法。Tunix 为热门算法(包括近端策略优化 [PPO]、直接偏好优化 [DPO] 等)提供开箱即用的解决方案。

在基础设施方面,Tunix 与 Pathways 集成,实现了单控制器架构,从而可以进行多节点 RL 训练。在训练方面,Tunix 原生支持参数高效训练(例如 LoRA),并利用 JAX 分片和 XLA(用于机器学习计算图的通用且可伸缩的并行化 [GSPMD])来生成高性能的计算图。它开箱即用,支持 Gemma 和 Llama 等热门开源模型。

主要优势
  • 简单性:Tunix 提供类似于客户端的高级 API,可消除底层分布式基础设施的复杂性。
  • 开发者效率:Tunix 通过内置算法和“recipe”加快研发生命周期,为您提供有效的模型,并让您快速迭代。
  • 性能和可伸缩性:Tunix 利用 Pathways 作为后端上的单个控制器,实现高效且可横向扩缩的训练基础设施。

应用层:生产和推理

在采用 JAX 方面,一直以来都存在一个挑战,那就是如何从研究过渡到生产。JAX AI 技术栈现在提供成熟的双管齐下的生产方案,可同时实现生态系统兼容性和 JAX 性能。

高性能 LLM 推理:vLLM 解决方案

vLLM-TPU 是 Google 的高性能推理栈,旨在在 Cloud TPU 上高效运行 PyTorch 和 JAX 大语言模型 (LLM)。为此,它将热门的开源 vLLM 框架与 Google 的 JAX 和 TPU 生态系统原生集成。

动机

随着对无缝、高性能且易于使用的推理解决方案的需求不断增长,该行业正在快速发展。开发者经常面临复杂且不一致的工具、性能不佳和模型兼容性受限制等重大挑战。vLLM 栈通过提供统一、高性能且直观的平台来解决这些问题。

设计

此解决方案扩展了 vLLM 框架,而不是重塑它。vLLM-TPU 是一款经过高度优化的开源 LLM 服务引擎,以高吞吐量而闻名,它通过使用 PagedAttention(可像管理虚拟内存一样管理 KV 缓存,以最大限度地减少分片)和连续批处理(可动态将请求添加到批处理,以提高利用率)等关键功能来实现高吞吐量。

vLLM-TPU 在此基础上开发了用于处理请求、调度和内存管理的核心组件。它引入了一个充当桥梁的基于 JAX 的后端,可将 vLLM 的计算图和内存操作转换为 TPU 可执行的代码。此后端负责处理设备交互、JAX 模型执行以及在 TPU 硬件上管理 KV 缓存的具体细节。它融入了 TPU 特有的优化,例如高效的注意力机制(例如,利用 JAX Pallas 内核实现 Ragged Paged Attention)和量化,所有这些都针对 TPU 架构量身定制。

主要优势
  • 用户无需支付任何加入/退出费用:用户可以轻松采用此解决方案。从用户体验的角度来看,在 TPU 上处理推理请求与在 GPU 上处理推理请求应该是一样的。用于启动服务器、接受提示和返回输出的 CLI 都是共享的。
  • 完全融入生态系统:此方法利用并贡献于 vLLM 接口和用户体验,确保兼容性和易用性。
  • TPU 和 GPU 之间的可替代性:该解决方案可在 TPU 和 GPU 上高效运行,让您能够灵活选择。
  • 经济高效(最佳性价比):优化性能,为热门模型提供最佳性价比。

JAX 服务:Orbax 序列化和 Neptune 服务引擎

对于 LLM 以外的模型,或者对于希望使用完全 JAX 原生流水线的用户,Orbax 序列化库和 Neptune 服务引擎 (NSE) 系统可提供端到端的高性能服务解决方案。

动机

从历史上看,JAX 模型通常依赖于曲折的生产路径,例如封装在 TensorFlow 图中并使用 TensorFlow Serving 进行部署。 这种方法带来了严重的限制和低效问题,迫使开发者与单独的生态系统互动,并减缓了迭代速度。专用 JAX 原生服务系统对于实现可持续性、降低复杂性和优化性能至关重要。

设计

此解决方案包含两个核心组件,如下图所示。

JAX 服务图

  1. Orbax 序列化库:提供方便用户使用的 API,用于将 JAX 模型序列化为新的、强大的 Orbax 序列化格式。此格式针对生产部署进行了优化。它使用 StableHLO 直接表示 JAX 模型计算,从而能够以原生方式表示计算图。Orbax 序列化库还利用 TensorStore 存储权重,从而实现快速加载检查点以进行服务。
  2. Neptune Serving Engine (NSE):这是一个配套的高性能灵活服务引擎(通常部署为 C++ 二进制文件),旨在以原生方式运行 Orbax 格式的 JAX 模型。NSE 提供生产必需的功能,例如快速模型加载、通过内置批处理实现高吞吐量并发服务、支持多个模型版本,以及单主机和多主机服务(利用 PJRT 和 Pathways)。将 Neptune Serving Engine 用于以下各项:
    • 非 LLM 模型:这是一种通用解决方案,非常适合 Recommender 系统、diffusion 模型和其他 AI 模型等工作负载。
    • 小型 LLM 和“一次性”服务:它专为以“一元”方式提供服务的非自回归模型或较小模型而设计,其中整个输出在一次传递中生成,无需像 KV 缓存那样进行复杂的状态管理。

简而言之,Neptune Serving Engine 填补了以下空白:为各种非大型自回归语言模型提供服务,为更广泛的机器学习生态系统提供高性能的 TPU 原生解决方案。

主要优势
  • JAX 原生服务:该解决方案是为 JAX 原生构建的,可消除模型序列化和部署中的框架间开销。这可确保在 CPU、GPU 和 TPU 上快速加载模型并优化执行。
  • 轻松部署到生产环境:序列化模型提供不受 Python 依赖项漂移影响的封闭式部署路径,并支持运行时模型完整性检查。这为 JAX 模型投入生产提供了顺畅直观的途径。
  • 增强的开发者体验:通过消除对繁琐的框架封装的需求,此解决方案可显著减少依赖项和系统复杂性,从而加快 JAX 开发者的迭代速度。

系统范围分析和性能剖析

XProf:深入的硬件集成性能剖析

XProf 是一种剖析和性能分析工具,可深入了解机器学习工作负载执行的各个方面,让您能够调试和优化性能。XProf 与 JAX 和 TPU 生态系统深度集成。

动机

一方面,机器学习工作负载正变得越来越复杂。另一方面,针对这些工作负载的专用硬件功能正在蓬勃发展。鉴于机器学习基础设施的成本非常高昂,因此有效匹配这两者对于确保最高性能和效率至关重要。这需要深入了解工作负载和硬件,并以可快速使用的方式呈现。XProf 在这方面表现出色。

设计

XProf 由两个主要组件组成:收集和分析。

  1. 收集:XProf 从各种来源捕获信息:JAX 代码中的注解、XLA 编译器中运算的成本模型,以及 TPU 内的专用硬件性能剖析功能。此收集操作可以通过程序化方式触发,也可以按需触发,从而生成全面的事件制品。
  2. 分析:XProf 会对收集的数据进行后处理,并创建一套强大的可视化图表,可通过浏览器访问。
主要优势

XProf 的真正强大之处在于它与全栈的深度集成,可提供广泛而深入的分析,这是共同设计的 JAX/TPU 生态系统的切实优势。

  • 与 TPU 共同设计:XProf 利用专门为无缝性能剖析文件收集而设计的硬件功能,使收集开销低于 1%。这样,性能剖析就可以成为开发过程中轻量级的迭代部分。
  • 分析的广度和深度:XProf 可针对多个维度进行深入分析。其工具包括:
    • Trace Viewer:在不同硬件单元(例如 TensorCore)上执行操作的时间轴视图。
    • HLO Op Profile:将总时间细分为不同的操作类别。
    • Memory Viewer:显示在分析窗口期间,不同操作的内存分配详情。
    • Roofline Analysis:帮助您确定特定操作是受计算限制还是受内存限制,以及它们与硬件的峰值能力相差多少。
    • Graph Viewer:提供由硬件执行的完整 HLO 图的视图。

比较视角:JAX/TPU 栈是一个极具吸引力的选择

现代机器学习领域提供了许多出色的成熟工具链。JAX AI 技术栈为专注于大规模、高性能机器学习的开发者提供了一系列独特而极具吸引力的优势,这些优势直接源于其模块化设计和深度硬件协同设计。

虽然许多框架都提供各种各样的功能,但 JAX AI 技术栈在开发生命周期的关键领域提供了独特而强大的差异化优势:

  • 更简单、更强大的开发者体验:Optax 的可链接梯度转换范式允许使用更强大、更灵活的优化策略,这些策略只会声明一次,而无需在训练循环中以命令式方式进行管理。在系统层面,Pathways 的更简单单控制器接口可消除多切片训练的复杂性,从而为研究人员带来显著的简化。
  • 专为实现大规模弹性而设计:JAX 栈专为极大规模训练而设计。Orbax 提供“超大规模级训练弹性”功能,例如紧急检查点和多层级检查点。Grain 可提供确定性全局 shuffle 和可设置检查点的数据加载器,从而为可重现性提供全面支持。能够以原子方式对数据流水线状态 (Grain) 和模型状态 (Orbax) 进行检查点设置,是保证长时间运行的作业可重现性的关键功能。
  • 完整的端到端生态系统:该栈提供了完整的端到端解决方案。开发者可以使用 MaxText 作为训练的 SOTA 参考,使用 Tunix 进行校准,并通过 vLLM-TPU(用于 vLLM 兼容性)和 NSE(用于 JAX 性能)遵循清晰的生产环境双路径。

虽然从高级软件的角度来看,许多栈都非常相似,但决定性因素往往是性能/TCO,而这正是 JAX 和 TPU 协同设计所带来的明显优势。这种性能/TCO 优势是软件和 TPU 硬件垂直集成的直接结果。XLA 编译器能够专门针对 TPU 架构融合操作,XProf 性能分析器能够使用硬件钩子进行开销低于 1% 的性能分析,这些都是深度集成带来的切实好处。

对于采用此栈的组织,JAX AI 技术栈功能完备的性质可最大限度地降低迁移成本。对于采用热门开放模型架构的客户,从其他框架改用 MaxText 通常只需设置配置文件。此外,该栈能够注入 safetensors 等常用检查点格式,因此可以迁移现有检查点,而无需进行成本高昂的重新训练。

下表列出了 JAX AI 技术栈提供的组件与其在其他框架或库中的等效项的对应关系。

功能 JAX 其他框架中的替代方案/等效方案5
编译器/运行时 XLA Inductor、eager
MultiPod 训练 Pathways Torch lightning 策略、Ray Train、Monarch(新)。
核心框架 JAX PyTorch
模型编写 Flax、Max* 模型 torch.nn.*、NVidia TransformerEngine、HuggingFace Transformer
优化器和损失 Optax torch.optim.*、torch.nn.*Loss
数据加载器 粒度 Ray Data、HuggingFace 数据加载器
检查点功能 Orbax PyTorch 分布式检查点,NeMo 检查点
量化 Qwix TorchAO、bitsandbytes
内核编写和知名实现 Pallas/Tokamax Triton/Helion、Liger-kernel、TransformerEngine
训练后/调优 Tunix VERL、NeMoRL
分析 XProf PyTorch Profiler、NSight Systems、NSight Compute
基础模型训练 MaxText、MaxDiffusion NeMo-Megatron、DeepSpeed、TorchTitan
LLM 推理 vLLM SGLang
非 LLM 推理 NSE Triton 推理服务器、RayServe

5此处的某些等效项并不总是真实的比较,因为与 JAX 相比,其他框架的 API 边界有所不同。等效库的列表并不详尽,并且经常会出现新库。

总结:一个持久耐用、可用于生产用途的平台,可满足 AI 未来需求

上表中的数据表明了一个显而易见的结论:这些栈在少数几个方面各有优缺点,但从软件角度来看,总体上非常相似。这两个栈都为基础模型的预训练、训练后自适应和部署提供了交钥匙解决方案。

JAX AI 技术栈提供了一个极具吸引力且强大的解决方案,可用于训练和部署任意规模的机器学习模型。它利用软件和 TPU 硬件之间的深度垂直集成,提供领先的性能和总拥有成本。

该栈基于经过实战检验的内部系统构建,可提供固有的可靠性和可伸缩性,让用户能够放心地开发和部署最大的模型。其模块化和可组合的设计基于 JAX AI 技术栈理念,可为用户提供卓越的自由度和控制力,让用户能够根据自己的具体需求定制栈,而不会受到单体式框架的限制。

借助 XLA 和 Pathways 提供的可扩容且容错的基础,JAX 提供的性能出色且功能强大的数值库,FlaxOptaxGrainOrbax 等强大的核心开发库,PallasTokamaxQwix 等高级性能工具,以及 MaxText、vLLM 和 NSE 中强大的应用和生产层,JAX AI 技术栈为用户提供了坚实的基础,可供用户在此基础上构建应用,并快速将先进的研究成果投入生产。