通过 JetStream 和 PyTorch 使用 GKE 上的 TPU 应用 LLM

本指南介绍了如何通过 JetStreamPyTorch 使用 Google Kubernetes Engine (GKE) 上的张量处理单元 (TPU) 应用大语言模型 (LLM)。在本指南中,您将模型权重下载到 Cloud Storage,然后使用运行 JetStream 的容器将其部署到 GKE AutopilotStandard 集群上。

如果您在 JetStream 上部署模型时需要利用 Kubernetes 功能提供的可伸缩性、弹性和成本效益,那么本指南是一个很好的起点。

本指南适用于使用 PyTorch 的生成式 AI 客户、GKE 的新用户或现有用户、机器学习工程师、MLOps (DevOps) 工程师或者对使用 Kubernetes 容器编排功能应用 LLM 感兴趣的平台管理员。

背景

通过 JetStream 使用 GKE 上的 TPU 应用 LLM,您可以构建一个可用于生产用途的强大服务解决方案,具备托管式 Kubernetes 的所有优势,包括经济高效、可伸缩性和更高的可用性。本部分介绍本教程中使用的关键技术。

TPU 简介

TPU 是 Google 定制开发的应用专用集成电路 (ASIC),用于加速机器学习和使用 TensorFlowPyTorchJAX 等框架构建的 AI 模型。

使用 GKE 中的 TPU 之前,我们建议您完成以下学习路线:

  1. 了解 Cloud TPU 系统架构中的当前 TPU 版本可用性。
  2. 了解 GKE 中的 TPU

本教程介绍如何部署各种 LLM 模型。GKE 在单主机 TPUv5e 节点上部署模型,并根据模型要求配置 TPU 拓扑,以低延迟提供提示。

JetStream 简介

JetStream 是由 Google 开发的开源推理服务框架。JetStream 可以在 TPU 和 GPU 上实现高性能、高吞吐量和内存优化的推理。JetStream 提供高级性能优化(包括连续批处理、KV 缓存优化和量化技术),以协助 LLM 部署。JetStream 支持 PyTorch/XLA 和 JAX TPU 服务,从而实现最佳性能。

连续批处理

连续批处理是一种可将传入的推理请求动态分为不同批次,从而缩短延迟时间并提高吞吐量的方法。

KV 缓存量化

KV 缓存量化涉及压缩注意力机制中使用的键值对缓存,从而降低内存要求。

Int8 权重量化

Int8 权重量化可将模型权重的精确率从 32 位浮点数降低到 8 位整数,从而加快计算速度并减少内存用量。

如需详细了解这些优化,请参阅 JetStream PyTorchJetStream MaxText 项目仓库。

PyTorch 简介

PyTorch 是由 Meta 开发的开源机器学习框架,现已成为 Linux 基金会旗下的一部分。PyTorch 提供了张量计算和深度神经网络等高级功能。