本教程介绍如何通过 Saxml,在 Google Kubernetes Engine (GKE) 上使用多主机 TPU 切片节点池部署和应用大语言模型 (LLM),以实现高效的可伸缩架构。
背景
Saxml 是一个实验性系统,应用 Paxml、JAX 和 PyTorch 框架。您可以采用这些框架,使用 TPU 来加速数据处理。为了演示 GKE 中 TPU 的部署,本教程应用了 175B LmCloudSpmd175B32Test 测试模型。GKE 分别在两个具有 4x8
拓扑的 v5e TPU 切片节点池上部署此测试模型。
为了正确部署测试模型,根据模型的大小定义了 TPU 拓扑。鉴于 Nx10 亿 16 位模型大约需要 2 倍 (2xN) GB 的内存,因此 175B LmCloudSpmd175B32Test 模型需要大约 350 GB 的内存。TPU v5e 单个 TPU 芯片具有 16 GB。为了支持 350 GB,GKE 需要 21 个 v5e TPU 芯片 (350/16= 21)。根据 TPU 配置的映射,本教程的正确 TPU 配置如下:
- 机器类型:
ct5lp-hightpu-4t
- 拓扑:
4x8
(32 个 TPU 芯片)
在 GKE 中部署 TPU 时,请务必选择正确的 TPU 拓扑来应用模型。如需了解详情,请参阅规划 TPU 配置。
准备环境
在 Google Cloud 控制台中,启动 Cloud Shell 实例:
打开 Cloud Shell设置默认环境变量:
gcloud config set project PROJECT_ID export PROJECT_ID=$(gcloud config get project) export CONTROL_PLANE_LOCATION=CONTROL_PLANE_LOCATION export BUCKET_NAME=PROJECT_ID-gke-bucket
替换以下值:
- PROJECT_ID:您的 Google Cloud 项目 ID。
- CONTROL_PLANE_LOCATION:集群控制平面的 Compute Engine 可用区。 选择可以使用
ct5lp-hightpu-4t
的可用区。
在此命令中,
BUCKET_NAME
会指定用于存储 Saxml 管理员服务器配置的 Google Cloud存储桶的名称。
创建 GKE Standard 集群
使用 Cloud Shell 执行以下操作:
创建使用适用于 GKE 的工作负载身份联合的 Standard 集群:
gcloud container clusters create saxml \ --location=${CONTROL_PLANE_LOCATION} \ --workload-pool=${PROJECT_ID}.svc.id.goog \ --cluster-version=VERSION \ --num-nodes=4
将
VERSION
替换为 GKE 版本号。GKE 在 1.27.2-gke.2100 及更高版本中支持 TPU v5e。如需了解详情,请参阅 GKE 中的 TPU 可用性。集群创建可能需要几分钟的时间。
创建第一个节点池,名为
tpu1
:gcloud container node-pools create tpu1 \ --location=${CONTROL_PLANE_LOCATION} \ --machine-type=ct5lp-hightpu-4t \ --tpu-topology=4x8 \ --num-nodes=8 \ --cluster=saxml
--num-nodes
标志的值是通过将 TPU 拓扑除以每个 TPU 切片的 TPU 芯片数量来计算的。在本示例中:(4 * 8)/4。创建第二个节点池,名为
tpu2
:gcloud container node-pools create tpu2 \ --location=${CONTROL_PLANE_LOCATION} \ --machine-type=ct5lp-hightpu-4t \ --tpu-topology=4x8 \ --num-nodes=8 \ --cluster=saxml
--num-nodes
标志的值是通过将 TPU 拓扑除以每个 TPU 切片的 TPU 芯片数量来计算的。在本示例中:(4 * 8)/4。
您已创建以下资源:
- 具有四个 CPU 节点的 Standard 集群。
- 两个具有
4x8
拓扑的 v5e TPU 切片节点池。每个节点池代表 8 个 TPU 切片节点,这些节点各自具有 4 个 TPU 芯片。
必须在至少具有 4x8
拓扑切片(32 个 v5e TPU 芯片)的多主机 v5e TPU 切片上应用 175B 模型。
创建 Cloud Storage 存储桶
创建 Cloud Storage 存储桶以存储 Saxml 管理员服务器配置。正在运行的管理员服务器会定期保存其状态和已发布模型的详细信息。
在 Cloud Shell 中,运行以下命令:
gcloud storage buckets create gs://${BUCKET_NAME}
使用适用于 GKE 的工作负载身份联合配置工作负载访问权限
为应用分配 Kubernetes ServiceAccount,并将该 Kubernetes ServiceAccount 配置为充当 IAM 服务账号。
配置
kubectl
以与您的集群通信:gcloud container clusters get-credentials saxml --location=${CONTROL_PLANE_LOCATION}
为您的应用创建 Kubernetes 服务账号:
kubectl create serviceaccount sax-sa --namespace default
为您的应用创建 IAM 服务账号:
gcloud iam service-accounts create sax-iam-sa
为您的 IAM 服务账号添加 IAM 政策绑定,以便对 Cloud Storage 执行读写操作:
gcloud projects add-iam-policy-binding ${PROJECT_ID} \ --member "serviceAccount:sax-iam-sa@${PROJECT_ID}.iam.gserviceaccount.com" \ --role roles/storage.admin
通过在两个服务账号之间添加 IAM 政策绑定,允许 Kubernetes 服务账号模拟 IAM 服务账号。此绑定允许 Kubernetes 服务账号充当 IAM 服务账号,以便 Kubernetes 服务账号可以对 Cloud Storage 执行读写操作。
gcloud iam service-accounts add-iam-policy-binding sax-iam-sa@${PROJECT_ID}.iam.gserviceaccount.com \ --role roles/iam.workloadIdentityUser \ --member "serviceAccount:${PROJECT_ID}.svc.id.goog[default/sax-sa]"
使用 IAM 服务账号的电子邮件地址为 Kubernetes 服务账号添加注解。这样,您的示例应用便知道要用于访问 Google Cloud 服务的服务账号。因此,在应用要使用任何标准 Google API 客户端库访问 Google Cloud 服务时,便会使用该 IAM 服务账号。
kubectl annotate serviceaccount sax-sa \ iam.gke.io/gcp-service-account=sax-iam-sa@${PROJECT_ID}.iam.gserviceaccount.com
部署 Saxml
在本部分中,您将部署 Saxml 管理服务器和 Saxml 模型服务器。
部署 Saxml 管理服务器
创建以下
sax-admin-server.yaml
清单:将
BUCKET_NAME
替换为您之前创建的 Cloud Storage 存储空间:perl -pi -e 's|BUCKET_NAME|BUCKET_NAME|g' sax-admin-server.yaml
应用清单:
kubectl apply -f sax-admin-server.yaml
验证管理员服务器 Pod 是否已启动并运行:
kubectl get deployment
输出类似于以下内容:
NAME READY UP-TO-DATE AVAILABLE AGE sax-admin-server 1/1 1 1 52s
部署 Saxml 模型服务器
在多主机 TPU 切片中运行的工作负载要求每个 Pod 都有一个稳定的网络标识符,以发现同一 TPU 切片中的对等方。如需定义这些标识符,请使用 IndexedJob、StatefulSet 及无头 Service 或 JobSet(它会自动为属于 JobSet 的所有作业创建无头 Service)。Jobset 是一种工作负载 API,可让您将一组 Kubernetes Job 作为一个单元进行管理。JobSet 最常见的应用场景是分布式训练,但您也可以使用它来运行批量工作负载。
以下部分介绍如何使用 JobSet 管理多组模型服务器 Pod。
安装 JobSet v0.2.3 或更高版本。
kubectl apply --server-side -f https://github.com/kubernetes-sigs/jobset/releases/download/JOBSET_VERSION/manifests.yaml
将
JOBSET_VERSION
替换为 JobSet 版本。例如v0.2.3
。验证 JobSet 控制器是否在
jobset-system
命名空间中运行:kubectl get pod -n jobset-system
输出类似于以下内容:
NAME READY STATUS RESTARTS AGE jobset-controller-manager-69449d86bc-hp5r6 2/2 Running 0 2m15s
在两个 TPU 切片节点池中部署两个模型服务器。保存以下
sax-model-server-set
清单:将
BUCKET_NAME
替换为您之前创建的 Cloud Storage 存储空间:perl -pi -e 's|BUCKET_NAME|BUCKET_NAME|g' sax-model-server-set.yaml
在此清单中:
replicas: 2
是作业副本的数量。每个作业代表一个模型服务器。因此,一组 8 个 Pod。parallelism: 8
和completions: 8
等于每个节点池中的节点数量。- 如果有任何 Pod 失败,
backoffLimit: 0
必须为零以将作业标记为失败。 ports.containerPort: 8471
是用于虚拟机通信的默认端口name: MEGASCALE_NUM_SLICES
会取消设置环境变量,因为 GKE 未运行多切片训练。
应用清单:
kubectl apply -f sax-model-server-set.yaml
验证 Saxml 管理服务器和模型服务器 Pod 的状态:
kubectl get pods
输出类似于以下内容:
NAME READY STATUS RESTARTS AGE sax-admin-server-557c85f488-lnd5d 1/1 Running 0 35h sax-model-server-set-sax-model-server-0-0-nj4sm 1/1 Running 0 24m sax-model-server-set-sax-model-server-0-1-sl8w4 1/1 Running 0 24m sax-model-server-set-sax-model-server-0-2-hb4rk 1/1 Running 0 24m sax-model-server-set-sax-model-server-0-3-qv67g 1/1 Running 0 24m sax-model-server-set-sax-model-server-0-4-pzqz6 1/1 Running 0 24m sax-model-server-set-sax-model-server-0-5-nm7mz 1/1 Running 0 24m sax-model-server-set-sax-model-server-0-6-7br2x 1/1 Running 0 24m sax-model-server-set-sax-model-server-0-7-4pw6z 1/1 Running 0 24m sax-model-server-set-sax-model-server-1-0-8mlf5 1/1 Running 0 24m sax-model-server-set-sax-model-server-1-1-h6z6w 1/1 Running 0 24m sax-model-server-set-sax-model-server-1-2-jggtv 1/1 Running 0 24m sax-model-server-set-sax-model-server-1-3-9v8kj 1/1 Running 0 24m sax-model-server-set-sax-model-server-1-4-6vlb2 1/1 Running 0 24m sax-model-server-set-sax-model-server-1-5-h689p 1/1 Running 0 24m sax-model-server-set-sax-model-server-1-6-bgv5k 1/1 Running 0 24m sax-model-server-set-sax-model-server-1-7-cd6gv 1/1 Running 0 24m
在此示例中,有 16 个模型服务器容器:sax-model-server-set-sax-model-server-0-0-nj4sm
和 sax-model-server-set-sax-model-server-1-0-8mlf5
是每个组中的两个主模型服务器。
您的 Saxml 集群有两个模型服务器,分别部署在两个具有 4x8
拓扑的 v5e TPU 切片节点池上。
部署 Saxml HTTP 服务器和负载均衡器
使用以下预构建映像 HTTP 服务器映像。保存以下
sax-http.yaml
清单:将
BUCKET_NAME
替换为您之前创建的 Cloud Storage 存储空间:perl -pi -e 's|BUCKET_NAME|BUCKET_NAME|g' sax-http.yaml
应用
sax-http.yaml
清单:kubectl apply -f sax-http.yaml
等待 HTTP 服务器容器完成创建:
kubectl get pods
输出类似于以下内容:
NAME READY STATUS RESTARTS AGE sax-admin-server-557c85f488-lnd5d 1/1 Running 0 35h sax-http-65d478d987-6q7zd 1/1 Running 0 24m sax-model-server-set-sax-model-server-0-0-nj4sm 1/1 Running 0 24m ...
等待系统为 Service 分配外部 IP 地址:
kubectl get svc
输出类似于以下内容:
NAME TYPE CLUSTER-IP EXTERNAL-IP PORT(S) AGE sax-http-lb LoadBalancer 10.48.11.80 10.182.0.87 8888:32674/TCP 7m36s
使用 Saxml
在 v5e TPU 多主机切片中的 Saxml 上加载、部署和应用模型:
加载模型
检索 Saxml 的负载均衡器 IP 地址。
LB_IP=$(kubectl get svc sax-http-lb -o jsonpath='{.status.loadBalancer.ingress[*].ip}') PORT="8888"
在两个 v5e TPU 切片节点池中加载
LmCloudSpmd175B
测试模型:curl --request POST \ --header "Content-type: application/json" \ -s ${LB_IP}:${PORT}/publish --data \ '{ "model": "/sax/test/spmd", "model_path": "saxml.server.pax.lm.params.lm_cloud.LmCloudSpmd175B32Test", "checkpoint": "None", "replicas": 2 }'
测试模型没有经过微调的检查点,权重是随机生成的。模型加载最多可能需要 10 分钟。
输出类似于以下内容:
{ "model": "/sax/test/spmd", "path": "saxml.server.pax.lm.params.lm_cloud.LmCloudSpmd175B32Test", "checkpoint": "None", "replicas": 2 }
检查模型就绪情况:
kubectl logs sax-model-server-set-sax-model-server-0-0-nj4sm
输出类似于以下内容:
... loading completed. Successfully loaded model for key: /sax/test/spmd
模型已完全加载。
获取模型的相关信息:
curl --request GET \ --header "Content-type: application/json" \ -s ${LB_IP}:${PORT}/listcell --data \ '{ "model": "/sax/test/spmd" }'
输出类似于以下内容:
{ "model": "/sax/test/spmd", "model_path": "saxml.server.pax.lm.params.lm_cloud.LmCloudSpmd175B32Test", "checkpoint": "None", "max_replicas": 2, "active_replicas": 2 }
应用模型
应用提示请求:
curl --request POST \
--header "Content-type: application/json" \
-s ${LB_IP}:${PORT}/generate --data \
'{
"model": "/sax/test/spmd",
"query": "How many days are in a week?"
}'
以下输出显示了模型响应的示例。此响应可能没有意义,因为测试模型具有随机权重。
取消发布模型
运行以下命令以取消发布模型:
curl --request POST \
--header "Content-type: application/json" \
-s ${LB_IP}:${PORT}/unpublish --data \
'{
"model": "/sax/test/spmd"
}'
输出类似于以下内容:
{
"model": "/sax/test/spmd"
}