本指南介绍了如何跨多个节点使用张量处理单元 (TPU),在 Google Kubernetes Engine (GKE) 上提供先进的大语言模型 (LLM),例如 Llama 3.1 405B。
本指南演示了如何使用可移植的开源技术(Kubernetes、JetStream、Pathways on Cloud 和 LeaderWorkerSet (LWS) API)在 GKE 上部署和提供 AI/机器学习工作负载,并利用 GKE 的精细控制、可伸缩性、弹性、可移植性和成本效益。
背景
大语言模型的规模不断扩大,已无法在单个主机 TPU 切片上运行。对于机器学习推理,您可以使用 Pathways on Cloud 在 GKE 上跨多个互连的 TPU 节点运行大规模多主机推理。在本指南中,您将逐步了解如何预配具有多主机 TPU 切片的 GKE 集群,使用 Pathways on Cloud 二进制文件,通过 MaxText 框架启动 JetStream 服务器,以及发出多主机推理请求。
通过 JetStream、MaxText 和 Pathways 使用 GKE 上的 TPU 提供 LLM,您可以构建一个可用于生产用途的强大服务解决方案,具备托管式 Kubernetes 的所有优势,包括经济高效、可伸缩性和更高的可用性。本部分介绍本教程中使用的关键技术。
TPU 简介
TPU 是 Google 定制开发的应用专用集成电路 (ASIC),用于加速机器学习和使用 TensorFlow、PyTorch 和 JAX 等框架构建的 AI 模型。
使用 GKE 中的 TPU 之前,我们建议您完成以下学习路线:
- 了解 Cloud TPU 系统架构中的当前 TPU 版本可用性。
- 了解 GKE 中的 TPU。
本教程介绍如何提供 Llama 3.1-405B 模型。GKE 在多主机 TPU v6e 节点上部署模型,并根据模型要求配置 TPU 拓扑,以低延迟响应提示。
Pathways on Cloud
Pathways 是一个适用于加速器的大规模编排层。Pathways 经过精心设计,可用于探索新的系统和机器学习研究理念,同时保持当前模型的出色性能。Pathways 使单个 JAX 客户端进程能够协调一个或多个大型 TPU 切片之间的计算,从而简化跨数百或数千个 TPU 芯片的机器学习计算。
JetStream
JetStream 是由 Google 开发的开源推理服务框架。JetStream 可以在 TPU 和 GPU 上实现高性能、高吞吐量和内存优化的推理。JetStream 提供高级性能优化(包括连续批处理、KV 缓存优化和量化技术),以协助 LLM 部署。JetStream 支持 PyTorch/XLA 和 JAX TPU 服务,从而优化性能。
MaxText
MaxText是一个高性能、可扩缩且适应性强的 JAX LLM 实现,基于如下开源 JAX 仓库构建:Flax、Orbax 和 Optax。MaxText 的仅解码器 LLM 实现是使用 Python 编写的。它大量利用 XLA 编译器来实现高性能,而无需构建自定义内核。
如需详细了解 MaxText 支持的最新模型和参数大小,请参阅 MaxText 项目仓库。
Llama 3.1 405B
Llama 3.1 405B 是由 Meta 提供的大语言模型,专为各种自然语言处理任务(包括文本生成、翻译和问答)而设计。GKE 提供所需的基础设施,以支持这种规模的模型的分布式训练和服务需求。
如需了解详情,请参阅 Llama 文档。
架构
本部分介绍本教程中使用的 GKE 架构。该架构包括一个 GKE Standard 集群,该集群用于预配 TPU 并托管 JetStream 和 Pathways 组件以部署和提供模型。
下图展示了此架构的组件:
此架构包括以下组件:
- GKE Standard 区域级集群。
- 一个多主机 TPU 切片节点池,用于托管 JetStream 部署和 Pathways 组件。
Pathways resource manager管理加速器资源,并协调用户作业的加速器分配。Pathways client与Pathways resource manager协同工作,以确定编译后的程序放置在何处以供执行。Pathways worker在加速器机器上运行并执行计算,然后通过 IFRT 代理服务器将数据发送回工作负载。IFRT proxy client实现了 OSS 临时框架运行时 (IFRT) API,并充当工作负载与 Pathways 组件之间的通信桥梁。IFRT proxy server从IFRT proxy client接收请求并将其转发给Pathways client,从而分配工作。JetStream-Pathways容器提供了一个基于 JAX 的推理服务器,该服务器接收推理请求并将其执行过程委托给Pathways workers- Service 组件将入站流量分布到所有
JetStream HTTP副本。 JetStream HTTP是一个 HTTP 服务器,它接受封装容器形式的 JetStream 所需格式的请求并将其发送到 JetStream 的 GRPC 客户端。