在 TPU 上扩缩模型

本文档概述了如何扩缩语言模型:TPU 的工作方式和彼此之间的通信方式、LLM 在真实硬件上的运行方式,以及如何在训练和推理期间并行运行模型,以便它们能够以大规模高效运行。我们提供的信息可帮助您评估训练 LLM 的费用、服务模型所需的内存量,以及如何有效地在多个 TPU 上对模型进行分片。

虽然许多深度学习都很复杂,但优化模型的性能(即使是大规模优化)也不一定很复杂。基本原则适用于任何地方(从处理单个加速器到数万个加速器),了解这些原则可让您执行许多有用的操作:

  • 估算模型各部分与理论最佳状态的接近程度。
  • 根据不同规模做出有关不同并行方案的明智选择(如何在多个设备之间拆分计算)。
  • 估算训练和运行大型 Transformer 模型所需的费用和时间。
  • 设计可利用 TPU 架构的算法。
  • 设计模型架构,明确了解哪些因素会限制算法性能。

前提条件

您应该对 LLM 和 Transformer 架构有基本的了解,但不一定了解它们如何大规模运行。您应该了解 LLM 训练的基础知识,最好对 JAX 有一些基本的了解。有关 Transformer 架构的实用背景阅读材料包括:

熟悉这些前提条件后,您应该能够轻松估算给定 TPU 平台上 Transformer 模型的最佳并行方案。您还可以估算训练和推理所需的时间。

模型伸缩的重要性

如今,LLM 和大多数小型模型都非常接近硬件限制,因此在开发模型时,您需要考虑大规模效率。如果基准测试的 20% 优势是以 20% 的屋顶线效率为代价,那么这种优势就无关紧要了。有前景的模型架构通常会失败,要么是因为它们无法大规模高效运行,要么是因为缺乏优化工作来使其能够高效运行。

模型伸缩的目标是能够增加用于训练或推理的芯片数量,同时实现吞吐量的成比例线性增长。这称为强伸缩。虽然添加额外的芯片(并行)通常会缩短计算时间,但也会增加芯片之间的通信成本。当通信时间长于计算时间时,模型会受到通信限制,无法很好地扩缩。 充分了解硬件以预测这些瓶颈的出现位置,可让您设计或重新配置模型以避免这些瓶颈。

以下部分概述了如何扩缩 TPU 硬件以及 Transformer 架构的演变过程。这些信息对于设计新架构的研究人员和致力于使当前一代大语言模型快速运行的工程师都很有用。

第 1 部分:概念

本部分介绍了屋顶线分析以及限制模型扩缩能力的因素(通信、计算和内存)。接下来,我们将介绍 TPU 的工作方式,包括作为单个芯片和(至关重要)作为互连系统(芯片间链路的带宽和延迟有限)。

  • 屋顶线 分析简介:本部分 介绍了如何根据 计算、通信和内存限制来估算算法的运行速度。
  • TPU 架构上的运算:本部分 介绍了 TPU 的架构、TPU 中不同硬件模块的工作方式,以及它们如何影响模型训练和服务。
  • 用于多 TPU 并行的模型分片:本部分通过解释分片 矩阵乘法深入探讨了模型分片和多 TPU 并行。

第 2 部分:扩缩 Transformer

务必了解 Transformer 架构的每个部分:每个矩阵的确切大小、规范化发生的位置,以及每个部分中的参数和 FLOP 数量。本部分仔细介绍了 Transformer 数学,展示了如何计算训练和推理的参数和 FLOP。这会告诉您模型将使用的内存量、您将花费在计算或通信上的时间,以及相对于前馈块,注意力何时变得重要。

最后,本部分可帮助您回答一个基本问题:给定特定大小的模型并提供一些芯片,如何并行化模型以保持强伸缩条件。为了回答这个问题,本部分讨论了用于在多个芯片上拆分模型的四种主要并行技术:数据、张量、流水线和专家。它还介绍了其他减少内存需求的技术,例如重新物化、ZeRO 驱动的模型分片、主机卸载和梯度累积。

  • Transformer 数学 运算简介:本 部分通过数学运算回答了有关 Transformer 在正向和反向传递期间使用的 FLOP 数量、计算参数数量的计算以及 KV 缓存大小的问题。
  • 用于训练的 Transformer 并行化:本部分 详细介绍了通过协调 FSDP、 Megatron 分片和流水线并行来最大限度提高训练效率的过程。它介绍了如何确定特定模型大小和批次大小在固定数量的芯片上的最佳分布,以实现峰值吞吐量。
    • 在 TPU上训练 Llama 3:本 小节介绍了如何在 TPU 上训练 Llama 3、可能需要多长时间以及可能需要多少费用。
  • 用于推理的 Transformer 伸缩:训练模型 后,需要提供服务。推理增加了一个新的考虑因素,即延迟时间,并改变了内存格局。本部分介绍了如何进行分离式服务以及如何考虑 KV 缓存。
    • 在 TPU上服务 Llama 3:本 小节介绍了如何在 TPU 上服务 Llama 3、可能需要多少 费用以及延迟时间和吞吐量之间的权衡。

第 3 部分:实际实现

本部分介绍了如何使用 JAX 实现伸缩概念,以及在出现问题时如何分析和调试代码。

  • 分析 TPU 程序: 实际 LLM 非常复杂,难以开发、优化和调试。 本部分介绍了 JAX + XLA 堆栈,以及如何使用 JAX/TensorBoard 分析器调试和修复实际问题。
  • 在 JAX 中对 TPU 进行编程: 本部分介绍了如何使用 JAX API 并行化计算。