本教程介绍如何使用 MaxText、Ray Train 和多切片 Trillium TPU 在 Google Kubernetes Engine (GKE) 上训练 Llama 3 70B 等大语言模型 (LLM)。本教程提供了一个完整的端到端演练,从配置必要的辅助数据中心网络到提交并成功运行跨 32 个物理 TPU 芯片的分布式训练工作负载。
本教程适用于平台管理员、运维人员和 AI 专家,他们希望了解如何克服在分布式多主机 TPU 切片上训练 700 亿参数模型时遇到的内存和网络问题。
背景
GKE、KubeRay、MaxText 和 TPU 的组合为大规模模型训练提供了一个强大且可伸缩的平台。本部分介绍本指南中使用的关键技术。
JAX
JAX 是一个面向加速器的数组计算和程序转换 Python 库,它利用 XLA 编译器创建高度优化的代码,可在加速器上高效扩展。
MaxText
MaxText 是一种高性能的开源 LLM 框架,旨在实现可伸缩性和可自定义性。MaxText 基于 JAX 构建,并经过优化,可在 Cloud TPU 上高效运行。
TPU
张量处理单元 (TPU) 是 Google 定制设计的加速器,旨在优化机器学习工作负载。与通用 CPU 或并行处理 GPU 不同,TPU 专为深度学习基础中的大规模矩阵和张量计算而设计,因此能够高效完成此特定任务。TPU 的主要优势在于大规模性能。
本教程采用多切片部署模式,使用第六代 TPU TPU Trillium。Cloud TPU 多切片是指两个或更多 Cloud TPU 切片通过数据中心网络 (DCN) 进行通信。多切片支持全栈、经济高效的大规模的训练,具有近线性伸缩能力,可达到数万个 TPU 芯片。如需详细了解 Multislice,请参阅 Cloud TPU Multislice 概览。
KubeRay
KubeRay 是一个 Kubernetes 操作器,可提供一种在 Kubernetes 上部署、管理和监控 Ray 应用的统一方式。KubeRay 操作器通过 Ray on GKE 插件进行安装和管理,这是在 GKE 上部署和管理 Ray 集群的推荐方法。
GKE 动态资源分配网络 (DRANET)
GKE DRANET(动态资源分配网络)是一项可将高性能网络设备动态附加到 Pod 的功能,可绕过标准 Kubernetes 网络,从而在 DCN 上实现高性能。
目标
本教程介绍了如何执行以下操作:
- 设置具有两个多主机 TPU 节点池的 GKE 集群。
- 为跨切片 TPU 通信配置辅助 DCN。
- 配置 KubeRay 以管理分布式训练环境。
- 通过使用动态资源分配 (DRA) 功能来部署 RayCluster 自定义资源,以实现网络连接。
- 利用 Ray Train 的 JaxTrainer 创建一个 Python 训练脚本,以在 TPU 切片中编排 MaxText 训练循环。
- 运行基准 Llama 3 8B 训练作业。
- 通过 DCN 利用 2D 分片(张量并行处理和 FSDP)将模型扩容到 Llama 3 70B。
准备工作
- 登录您的 Google Cloud 账号。如果您是 Google Cloud新手,请 创建一个账号来评估我们的产品在实际场景中的表现。新客户还可获享 $300 赠金,用于运行、测试和部署工作负载。
-
安装 Google Cloud CLI。
-
如果您使用的是外部身份提供方 (IdP),则必须先使用联合身份登录 gcloud CLI。
-
如需初始化 gcloud CLI,请运行以下命令:
gcloud init -
选择或创建项目所需的角色
- 选择项目:选择项目不需要特定的 IAM 角色,您可以选择已获授角色的任何项目。
-
创建项目:如需创建项目,您需要拥有 Project Creator 角色 (
roles/resourcemanager.projectCreator),该角色包含resourcemanager.projects.create权限。了解如何授予角色。
-
创建 Google Cloud 项目:
gcloud projects create PROJECT_ID
将
PROJECT_ID替换为您要创建的 Google Cloud 项目的名称。 -
选择您创建的 Google Cloud 项目:
gcloud config set project PROJECT_ID
将
PROJECT_ID替换为您的 Google Cloud 项目名称。
启用所需的 API:
启用 API 所需的角色
如需启用 API,您需要拥有 Service Usage Admin IAM 角色 (
roles/serviceusage.serviceUsageAdmin),该角色包含serviceusage.services.enable权限。了解如何授予角色。gcloud services enable container.googleapis.com
cloudbuild.googleapis.com -
安装 Google Cloud CLI。
-
如果您使用的是外部身份提供方 (IdP),则必须先使用联合身份登录 gcloud CLI。
-
如需初始化 gcloud CLI,请运行以下命令:
gcloud init -
选择或创建项目所需的角色
- 选择项目:选择项目不需要特定的 IAM 角色,您可以选择已获授角色的任何项目。
-
创建项目:如需创建项目,您需要拥有 Project Creator 角色 (
roles/resourcemanager.projectCreator),该角色包含resourcemanager.projects.create权限。了解如何授予角色。
-
创建 Google Cloud 项目:
gcloud projects create PROJECT_ID
将
PROJECT_ID替换为您要创建的 Google Cloud 项目的名称。 -
选择您创建的 Google Cloud 项目:
gcloud config set project PROJECT_ID
将
PROJECT_ID替换为您的 Google Cloud 项目名称。
启用所需的 API:
启用 API 所需的角色
如需启用 API,您需要拥有 Service Usage Admin IAM 角色 (
roles/serviceusage.serviceUsageAdmin),该角色包含serviceusage.services.enable权限。了解如何授予角色。gcloud services enable container.googleapis.com
cloudbuild.googleapis.com -
向您的用户账号授予角色。对以下每个 IAM 角色运行以下命令一次:
roles/container.admin, roles/iam.serviceAccountAdmin, roles/cloudbuild.builds.editorgcloud projects add-iam-policy-binding PROJECT_ID --member="user:USER_IDENTIFIER" --role=ROLE
替换以下内容:
PROJECT_ID:您的项目 ID。USER_IDENTIFIER:您的用户 账号的标识符。例如,myemail@example.com。ROLE:您授予用户账号的 IAM 角色。
- 由于本教程使用 TPU Trillium (v6e),请选择可用的区域或可用区。如需了解详情,请参阅 Cloud TPU 配额。
准备环境
在本教程中,您将使用 Cloud Shell。gcloudCloud Shellhelm 预安装了本教程中使用的 kubectl、 和 命令行工具。
前往 Google Cloud 控制台。
在 Google Cloud 控制台窗口顶部,点击激活 Cloud Shell
按钮。一个 Cloud Shell 会话随即会在Google Cloud 控制台中的新框架内打开,并显示命令行提示符。
在终端中,克隆
kubernetes-engine-samples代码库:git clone https://github.com/GoogleCloudPlatform/kubernetes-engine-samples.git切换到包含示例文件的目录:
cd kubernetes-engine-samples/ai-ml/gke-ray/raytrain/maxtext创建并激活 Python 虚拟环境:
python3 -m venv ray-env source ray-env/bin/activate安装 Ray CLI:
pip install "ray[default]==2.55.0"设置以下环境变量:
export PROJECT_ID=$(gcloud config get project) export PROJECT_NUMBER=$(gcloud projects describe ${PROJECT_ID} --format="value(projectNumber)") export GS_BUCKET=GS_BUCKET export KSA_NAME=KSA_NAME export NAMESPACE=default export CLUSTER_NAME=CLUSTER_NAME export REGION=REGION export ZONE=ZONE export CLUSTER_VERSION=1.35.2-gke.1842000替换以下内容:
GS_BUCKET:Cloud Storage 存储桶的名称。KSA_NAME:Kubernetes ServiceAccount 的名称。CLUSTER_NAME:新集群的名称。REGION:您的 TPU Trillium 容量可用的区域。ZONE:TPU Trillium 容量可用的可用区。如需了解详情,请参阅 GKE 中的 TPU 可用性。
为 Cloud TPU Multislice 配置集群网络
在多主机 TPU 切片中,TPU 设备通过高速芯片间互连进行通信。不过,在运行多切片作业时,TPU 切片必须通过 DCN 相互通信。标准 Kubernetes Pod 网络可能会导致此流量出现瓶颈。ct6e-standard-4t 机器类型由多个物理网络接口卡 (NIC) 提供支持。为实现最佳性能,您可以创建两个额外的 VPC 网络,并使用 GKE DRANET 将它们直接连接到 Ray Pod。
创建两个具有较大最大训练单元 (MTU) 的额外 VPC 网络:
gcloud compute networks create ${CLUSTER_NAME}-net-1 \ --subnet-mode=custom \ --mtu=8896 gcloud compute networks create ${CLUSTER_NAME}-net-2 \ --subnet-mode=custom \ --mtu=8896创建专用子网:
gcloud compute networks subnets create tpu-subnet-1 \ --network=${CLUSTER_NAME}-net-1 \ --region=${REGION} \ --range=10.50.0.0/16 gcloud compute networks subnets create tpu-subnet-2 \ --network=${CLUSTER_NAME}-net-2 \ --region=${REGION} \ --range=10.60.0.0/16
创建 GKE 集群
您可以在 GKE Autopilot 或 Standard 集群中的 TPU 上配置 KubeRay。我们建议您使用 Autopilot 集群获得全托管式 Kubernetes 体验。如需选择最适合您的工作负载的 GKE 操作模式,请参阅 GKE 操作模式简介。
如需使用 GKE 管理的 DRANET,您的集群必须使用 1.35.2-gke.1842000 或更高版本(对于 Autopilot 模式),或者 1.34.1-gke.1829001 或更高版本(对于标准模式)。本教程使用版本 1.35.2-gke.1842000。
Autopilot
在 Cloud Shell 中,运行以下命令:
gcloud container clusters create-auto $CLUSTER_NAME \ --enable-ray-operator \ --machine-type=n1-standard-16 \ --location=$REGION \ --cluster-version=${CLUSTER_VERSION}如需与集群通信,请配置
kubectl:gcloud container clusters get-credentials CLUSTER_NAME \ --location=$REGION
标准
在 Cloud Shell 中,运行以下命令以创建启用 Ray operator 插件的 Standard 集群:
gcloud container clusters create $CLUSTER_NAME \ --addons=RayOperator,GcsFuseCsiDriver \ --machine-type=n1-standard-16 \ --enable-dataplane-v2 \ --workload-pool=$PROJECT_ID.svc.id.goog \ --location=$ZONE \ --cluster-version=${CLUSTER_VERSION}此命令还会启用
GcsFuseCsiDriver,从而允许 Pod 将 Cloud Storage 存储分区作为本地文件系统进行装载。集群创建可能需要几分钟的时间。如需与集群通信,请配置
kubectl:gcloud container clusters get-credentials CLUSTER_NAME \ --location=$ZONE创建第一个启用了 GKE DRANET 的多主机 TPU 切片节点池:
gcloud container node-pools create v6e-16-0 \ --location=$ZONE \ --cluster=$CLUSTER_NAME \ --machine-type=ct6e-standard-4t \ --threads-per-core=1 \ --tpu-topology=4x4 \ --num-nodes=4 \ --additional-node-network=network=${CLUSTER_NAME}-net-1,subnetwork=tpu-subnet-1 \ --additional-node-network=network=${CLUSTER_NAME}-net-2,subnetwork=tpu-subnet-2 \ --node-labels=cloud.google.com/gke-networking-dra-driver=true \ --enable-gvnic \ --scopes=https://www.googleapis.com/auth/cloud-platform创建第二个 TPU 切片节点池:
gcloud container node-pools create v6e-16-1 \ --location=$ZONE \ --cluster=$CLUSTER_NAME \ --machine-type=ct6e-standard-4t \ --threads-per-core=1 \ --tpu-topology=4x4 \ --num-nodes=4 \ --additional-node-network=network=${CLUSTER_NAME}-net-1,subnetwork=tpu-subnet-1 \ --additional-node-network=network=${CLUSTER_NAME}-net-2,subnetwork=tpu-subnet-2 \ --node-labels=cloud.google.com/gke-networking-dra-driver=true \ --enable-gvnic \ --scopes=https://www.googleapis.com/auth/cloud-platform
GKE 会预配一个由四个 TPU Trillium (v6e) 虚拟机组成的节点池,这些虚拟机共同配置为一个具有 4x4 拓扑的多主机 TPU 切片。此节点池已准备好用于分布式训练工作负载。
启用了 Ray 操作器的 GKE 集群会自动在集群中安装 KubeRay 和 KubeRay TPU Webhook。
配置 Cloud Storage 存储桶和服务账号
创建一个 Cloud Storage 存储桶,用于在多主机 TPU 节点之间共享检查点。
gsutil mb -p ${PROJECT_ID} -c STANDARD -l ${REGION} gs://${GS_BUCKET}如需启用对 Cloud Storage 存储桶的访问权限,请创建 Kubernetes ServiceAccount:
kubectl create serviceaccount ${KSA_NAME} --namespace ${NAMESPACE}如需启用对 Cloud Storage 存储桶的访问权限,请向服务账号添加所需的 IAM 政策绑定:
gcloud storage buckets add-iam-policy-binding gs://${GS_BUCKET} \ --member "principal://iam.googleapis.com/projects/${PROJECT_NUMBER}/locations/global/workloadIdentityPools/${PROJECT_ID}.svc.id.goog/subject/ns/${NAMESPACE}/sa/${KSA_NAME}" \ --role "roles/storage.objectUser"
创建训练脚本
maxtext_multi_slice_trainer.py 脚本使用 Ray Train 的 JaxTrainer 在两个 TPU 切片上运行分布式 MaxText 训练作业。该脚本为八个多主机 TPU 工作器配置训练环境,并在每个工作器节点上运行 MaxText 训练作业。train_loop_per_worker 函数封装了 MaxText 主要入口点,并使用 Ray 的分布式调度程序在多主机 TPU 切片上执行 MaxText 训练器:
上述脚本定义了一个 JaxTrainer 实例,该实例请求 8 个 worker 和 4x4 的拓扑。在内部,Ray 会在两个 TPU 切片之间预配 SlicePlacementGroup,并帮助确保 Ray Train 工作器在两个切片上以原子方式运行,每个主机上有一个工作器。
训练模型
当前目录中的
ray-cluster.tpu-multi-slice.yaml清单定义了 RayCluster 自定义资源。此清单包含 DRANETResourceClaimTemplate,用于为 GKE DRANET 和多切片配置网络设备:上述 RayCluster 规范会创建一个 TPU 工作进程组,其中每个副本有 8 个工作进程 (
numOfHosts: 4),并且有两个副本。每个工作器请求 4 个 TPU 芯片 (google.com/tpu: "4")。每个工作器都调度在同一共置多主机切片中的 TPU Trillium 节点 (tpu-v6e-slice) 上。KubeRay 会以原子方式扩缩切片中的所有四个工作器。所需的 JAX 环境变量以及用于调度的 Pod 亲和性由 GKE 通过变更 webhook 进行引导。如需创建 RayCluster,请应用清单:
envsubst < ray-cluster.tpu-multi-slice.yaml | kubectl apply -f -验证集群是否已准备就绪并正在运行:
kubectl get rayclusters maxtext-tpu-cluster输出应类似如下所示:
NAME DESIRED WORKERS AVAILABLE WORKERS CPUS MEMORY GPUS STATUS AGE maxtext-tpu-cluster 8 8 72 1579277216Ki 0 ready 2m11s如需通过 Ray 头服务访问 Ray 信息中心,请建立端口转发会话:
kubectl port-forward svc/maxtext-tpu-cluster-head-svc 8265:8265 2>&1 >/dev/null &验证 RayCluster 是否可从本地环境访问:
ray list nodes --address http://localhost:8265输出应类似如下所示:
ray list nodes --address http://localhost:8265 2026-04-21 10:20:20,080 - INFO - Note: NumExpr detected 64 cores but "NUMEXPR_MAX_THREADS" not set, so enforcing safe limit of 8. 2026-04-21 10:20:20,080 - INFO - NumExpr defaulting to 8 threads. ======== List: 2026-04-21 10:20:20.945431 ======== Stats: ------------------------------ Total: 9 Table: ------------------------------ NODE_ID NODE_IP IS_HEAD_NODE STATE STATE_MESSAGE NODE_NAME RESOURCES_TOTAL LABELS 0 4f0e4d742de5375047c7688f4d2bc64a42d1e5c77c2d8344b3b375a1 10.68.9.5 False ALIVE 10.68.9.5 CPU: 8.0 ray.io/accelerator-type: TPU-V6E TPU: 4.0 ray.io/node-group: tpu-group accelerator_type:TPU-V6E: 1.0 ray.io/node-id: 4f0e4d742... memory: 186.265 GiB ray.io/tpu-pod-type: v6e-16 node:10.68.9.5: 1.0 ray.io/tpu-slice-name: tpu-group-0 object_store_memory: 186.265 GiB ray.io/tpu-topology: 4x4 tpu-group-0: 1.0 ray.io/tpu-worker-id: '1' ... 6 ce7056807b95831ce107ba1951dac34b80635e6fdbb312e7f9649938 10.68.2.9 True ALIVE 10.68.2.9 CPU: 8.0 ray.io/node-group: headgroup memory: 16.000 GiB ray.io/node-id: ce7056807... node:10.68.2.9: 1.0 node:__internal_head__: 1.0 object_store_memory: 4.765 GiB ...下载基本 MaxText 配置文件。训练脚本需要此文件来设置模型的默认超参数:
curl -O https://raw.githubusercontent.com/google/maxtext/maxtext-v0.2.1/src/maxtext/configs/base.yml将 JaxTrainer 脚本提交到 RayCluster,并检查 RayJob 是否成功完成:
Llama 3 8B
ray job submit \
--address http://localhost:8265 \
--working-dir . \
--runtime-env-json '{"excludes": ["ray-env", ".git"]}' \
-- python maxtext_multi_slice_trainer.py \
base.yml \
base_output_directory=/data/ \
dataset_type=synthetic \
per_device_batch_size=4 \
max_target_length=4096 \
model_name=llama3-8b \
steps=100 \
ici_fsdp_parallelism=4 \
ici_tensor_parallelism=4 \
run_name=rayjob-multi-slice
Llama 3 70B
ray job submit \
--address http://localhost:8265 \
--working-dir . \
--runtime-env-json '{"excludes": ["ray-env", ".git"]}' \
-- python maxtext_multi_slice_trainer.py \
base.yml \
base_output_directory=/data/ \
dataset_type=synthetic \
per_device_batch_size=2 \
max_target_length=4096 \
model_name=llama3-70b \
steps=100 \
ici_tensor_parallelism=4 \
ici_fsdp_parallelism=4 \
dcn_fsdp_parallelism=2 \
dcn_data_parallelism=1 \
remat_policy=full \
run_name=rayjob-multi-slice-70b-fsdp
上述命令会提交 Python 脚本,该脚本会调用 JaxTrainer Ray 代码到 RayCluster。ray job submit 命令包含一些特定于 MaxText 的实参,用于传递给模型配置。
在终端中,您应该会看到类似以下内容的 Llama 3 70B 作业输出:
[process=5][thread=save_finalize][step=99] CheckpointManager Save Finalize is done on all hosts. [repeated 7x across cluster]
(RayTrainWorker pid=130520, ip=10.60.7.7) [process=5][thread=TrainingThread(train_fn_with_final_checkpoint_flush)][step=99][wait_until_finished] Done waiting for Save Finalize thread (save_finalize) running at step=99. [repeated 7x across cluster]
(RayTrainWorker pid=130520, ip=10.60.7.7) [process=5][thread=TrainingThread(train_fn_with_final_checkpoint_flush)][wait_until_finished] No Save Finalize thread to wait for. Returning. [repeated 6x across cluster]
(RayTrainWorker pid=130520, ip=10.60.7.7) completed step: 99, seconds: 0.693, TFLOP/s/device: 83.171, Tokens/s/device: 11819.175, total_weights: 262144, loss: 0.334 [repeated 6x across cluster]
------------------------------------------
Job 'raysubmit_XwUdZMrhsYRKvjqs' succeeded
------------------------------------------
在 Spot 虚拟机上运行多切片弹性训练
使用 TPU 等热门加速器时,利用 Spot 虚拟机可能会大幅降低费用。不过,Spot 虚拟机可能会被意外抢占。
Ray Train 支持弹性训练,可让作业动态扩缩参与的 TPU 切片数量,而不会失败。如果某个 slice 被抢占,Ray 会暂停训练循环,等待剩余的工作器重新组织,从最新的 MaxText 检查点恢复,然后在较小的占用空间上恢复训练。
如需启用弹性训练,请将 ScalingConfig 中的 num_workers 参数从静态整数更改为表示 (minimum_workers, maximum_workers) 的元组。此外,向 RunConfig 添加 FailureConfig(max_failures=3),指示 Ray Train 在工作器被抢占时重试训练循环最多 3 次,而不是让作业完全失败。
更新 Ray Train 脚本
当前目录中的
maxtext_elastic_trainer.py脚本可实现弹性训练。请注意,它设置了num_workers=(4,8),这会告知 Ray,如果至少有一个 16 芯片切片(四个工作器)可用,则继续运行,但如果可能,则扩容到两个切片(八个工作器)。它包含一个FailureConfig,用于启用弹性训练、定义重试次数,并帮助确保作业在抢占后继续运行:使用 Ray 作业 CLI 提交作业。请务必提供唯一的
run_name,以免检查点与之前的运行发生冲突。ray job submit \ --address http://localhost:8265 \ --working-dir . \ --runtime-env-json '{"excludes": ["ray-env", ".git"]}' \ -- python maxtext_elastic_trainer.py \ base.yml \ base_output_directory=/data/ \ dataset_type=synthetic \ per_device_batch_size=4 \ max_target_length=4096 \ model_name=llama3-8b \ steps=100 \ ici_fsdp_parallelism=4 \ ici_tensor_parallelism=4 \ run_name=rayjob-elastic-8b如需在训练期间模拟节点终止或抢占,请删除 Pod。
kubectl delete pod $(kubectl get pods -l ray.io/node-type=worker -o jsonpath='{.items[0].metadata.name}')
终端会记录工作器故障,但编排控制器会保持作业运行,并在最低拓扑可用后自动从 /data/rayjob-elastic-8b/checkpoints 检查点恢复。
由于 MaxText 会在恢复时动态重新计算设备网格,因此您无需编写任何自定义逻辑来处理拓扑缩小时的检查点重新分片。JAX 的 Orbax 检查点程序会在继续执行训练循环之前,自动将保存的权重重新分片到新的物理布局中。以下输出显示了 Ray Train 控制器在训练期间检测到集群中新近可用的 TPU 资源,并执行从一个切片(四个工作器)到两个切片(八个工作器)的伸缩操作。
...
(pid=, ip=10.68.9.5) W0421 04:19:07.570048 20579 grpc_transport.cc:1930] GetMultiSliceTopology returned with status: UNAVAILABLE: failed to connect to all addresses; last error: UNKNOWN: ipv4:10.68.8.5:9915: connect endpoint failed (Failed to connect to remote host: Connection refused)
...
(TrainController pid=23150) Detected changes in the cluster resources. Deciding to resize the worker group from 4 -> 8 workers.
(TrainController pid=23150) Using SlicePlacementGroup utility to reserve 2 slice(s) with topology '4x4'...
(TrainController pid=23150) Attempting to start training worker group of size 8 with the following resources: [{'TPU': 4, 'accelerator_type:TPU-V6E': 0.001}] * 8
清理
为避免因本教程中使用的资源导致您的 Google Cloud 账号产生费用,请删除包含这些资源的项目,或者保留该项目但删除各个资源。
删除 RayCluster:
kubectl delete raycluster maxtext-tpu-cluster删除 GKE 集群:
gcloud container clusters delete $CLUSTER_NAME --zone=$ZONE删除 Cloud Storage 存储桶:
gsutil rm -r gs://${GS_BUCKET}
后续步骤
- 了解 Ray on Kubernetes。
- 了解如何在 GKE 上使用 TPU 部署 vLLM。
- 详细了解 GKE 中的 TPU。