本教程介绍如何在 Google Kubernetes Engine (GKE) 上使用张量处理单元 (TPU) 来提供 Gemma 大语言模型 (LLM)。您会将包含 JetStream 和 MaxText 的预构建容器部署到 GKE。您还需要配置 GKE,以便在运行时从 Cloud Storage 加载 Gemma 7B 权重。
本教程适用于机器学习 (ML) 工程师、平台管理员和运维人员,以及对使用 Kubernetes 容器编排功能提供 LLM 感兴趣的数据和 AI 专家。如需详细了解我们在Google Cloud 内容中提及的常见角色和示例任务,请参阅常见的 GKE 用户角色和任务。
在阅读本页面之前,请确保您熟悉以下内容:
- Autopilot 模式和 Standard 模式
- Cloud TPU 系统架构中的当前 TPU 版本可用性
- GKE 中的 TPU
背景
本部分介绍本教程中使用的关键技术。
Gemma
Gemma 是一组公开提供的轻量级生成式人工智能 (AI) 模型(根据开放许可发布)。这些 AI 模型可以在应用、硬件、移动设备或托管服务中运行。您可以使用 Gemma 模型生成文本,但也可以针对专门任务对这些模型进行调优。
如需了解详情,请参阅 Gemma 文档。
TPU
TPU 是 Google 定制开发的应用专用集成电路 (ASIC),用于加速机器学习和使用 TensorFlow、PyTorch 和 JAX 等框架构建的 AI 模型。
本教程介绍如何应用 Gemma 7B 模型。GKE 在单主机 TPUv5e 节点上部署模型,并根据模型要求配置 TPU 拓扑,以低延迟提供提示。
JetStream
JetStream 是由 Google 开发的开源推理服务框架。JetStream 可以在 TPU 和 GPU 上实现高性能、高吞吐量和内存优化的推理。它提供高级性能优化(包括连续批处理和量化技术),以协助 LLM 部署。JetStream 支持 PyTorch/XLA 和 JAX TPU 服务,从而实现最佳性能。
如需详细了解这些优化,请参阅 JetStream PyTorch 和 JetStream MaxText 项目仓库。
MaxText
MaxText是一个高性能、可扩缩且适应性强的 JAX LLM 实现,基于如下开源 JAX 仓库构建:Flax、Orbax 和 Optax。MaxText 的仅解码器 LLM 实现是使用 Python 编写的。它大量利用 XLA 编译器来实现高性能,而无需构建自定义内核。
如需详细了解 MaxText 支持的最新模型和参数大小,请参阅 MaxtText 项目仓库。