本教學課程說明如何使用 MaxText、Ray Train 和 Multislice Trillium TPU,在 Google Kubernetes Engine (GKE) 上訓練 Llama 3 70B 等大型語言模型 (LLM)。本教學課程提供完整的端對端逐步說明,從設定必要的次要資料中心網路,到提交並成功執行 32 個實體 TPU 晶片的分散式訓練工作負載,都有詳細的介紹。
本教學課程的適用對象為平台管理員、操作員和 AI 專家,他們想瞭解如何克服記憶體和網路問題,在分散式多主機 TPU 節點上訓練 700 億個參數的模型。
背景
結合 GKE、KubeRay、MaxText 和 TPU,可為大規模模型訓練作業提供強大且可擴充的平台。本節說明本指南中使用的主要技術:
JAX
JAX 是 Python 程式庫,可進行以加速器為導向的陣列運算和程式轉換,並運用 XLA 編譯器建立高度最佳化的程式碼,在加速器上有效率地擴展。
MaxText
MaxText 是高效能的開放原始碼 LLM 架構,專為擴充性和可自訂性而設計。MaxText 以 JAX 為基礎建構,並經過最佳化,可在 Cloud TPU 上有效率地執行。
TPU
Tensor Processing Unit (TPU) 是 Google 打造的特製加速器,可將機器學習工作負載調整到最佳狀態。與一般用途的 CPU 或平行處理 GPU 不同,TPU 專為深度學習基礎的大量矩陣和張量運算而設計,因此能有效執行這項特定工作。TPU 的主要優勢是可大規模提升效能。
本教學課程使用第六代 TPU TPU Trillium,採用多配量部署模式。Cloud TPU Multislice 是指兩個以上的 Cloud TPU 配量透過資料中心網路 (DCN) 通訊。Multislice 提供完整堆疊,能以符合成本效益的方式進行大規模訓練,並近線性擴充至數萬個 TPU 晶片。如要進一步瞭解多配量,請參閱 Cloud TPU 多配量總覽。
KubeRay
KubeRay 是 Kubernetes 運算子,可提供統一的方式,在 Kubernetes 上部署、管理及監控 Ray 應用程式。KubeRay 運算子會透過 Ray on GKE 外掛程式安裝及管理,建議您使用這個外掛程式在 GKE 上部署及管理 Ray 叢集。
GKE Dynamic Resource Allocation Network (DRANET)
GKE DRANET (動態資源分配網路) 這項功能會將高效能網路裝置動態附加至 Pod,略過標準 Kubernetes 網路,並透過 DCN 啟用高效能。
目標
本教學課程說明如何執行下列操作:
- 設定具有兩個多主機 TPU 節點集區的 GKE 叢集。
- 設定次要 DCN,用於跨切片 TPU 通訊。
- 設定 KubeRay,管理分散式訓練環境。
- 使用動態資源分配 (DRA) 部署 RayCluster 自訂資源,以進行網路附件。
- 利用 Ray Train 的 JaxTrainer 建立 Python 訓練指令碼,在 TPU 切片中協調 MaxText 訓練迴圈。
- 執行基準 Llama 3 8B 訓練工作。
- 透過 DCN 上的 2D 分片 (張量平行處理和 FSDP),擴充至 Llama 3 70B。
事前準備
- 登入 Google Cloud 帳戶。如果您是 Google Cloud新手,歡迎 建立帳戶,親自評估產品在實際工作環境中的成效。新客戶還能獲得價值 $300 美元的免費抵免額,可用於執行、測試及部署工作負載。
-
安裝 Google Cloud CLI。
-
若您採用的是外部識別資訊提供者 (IdP),請先使用聯合身分登入 gcloud CLI。
-
執行下列指令,初始化 gcloud CLI:
gcloud init -
選取或建立專案所需的角色
- 選取專案:選取專案時,不需要具備特定 IAM 角色,只要您已獲授角色,即可選取任何專案。
-
建立專案:如要建立專案,您需要具備專案建立者角色 (
roles/resourcemanager.projectCreator),其中包含resourcemanager.projects.create權限。瞭解如何授予角色。
-
建立 Google Cloud 專案:
gcloud projects create PROJECT_ID
將
PROJECT_ID替換為您要建立的 Google Cloud 專案名稱。 -
選取您建立的 Google Cloud 專案:
gcloud config set project PROJECT_ID
將
PROJECT_ID替換為 Google Cloud 專案名稱。
啟用必要的 API:
啟用 API 時所需的角色
如要啟用 API,您需要具備服務使用情形管理員 IAM 角色 (
roles/serviceusage.serviceUsageAdmin),其中包含serviceusage.services.enable權限。瞭解如何授予角色。gcloud services enable container.googleapis.com
cloudbuild.googleapis.com -
安裝 Google Cloud CLI。
-
若您採用的是外部識別資訊提供者 (IdP),請先使用聯合身分登入 gcloud CLI。
-
執行下列指令,初始化 gcloud CLI:
gcloud init -
選取或建立專案所需的角色
- 選取專案:選取專案時,不需要具備特定 IAM 角色,只要您已獲授角色,即可選取任何專案。
-
建立專案:如要建立專案,您需要具備專案建立者角色 (
roles/resourcemanager.projectCreator),其中包含resourcemanager.projects.create權限。瞭解如何授予角色。
-
建立 Google Cloud 專案:
gcloud projects create PROJECT_ID
將
PROJECT_ID替換為您要建立的 Google Cloud 專案名稱。 -
選取您建立的 Google Cloud 專案:
gcloud config set project PROJECT_ID
將
PROJECT_ID替換為 Google Cloud 專案名稱。
啟用必要的 API:
啟用 API 時所需的角色
如要啟用 API,您需要具備服務使用情形管理員 IAM 角色 (
roles/serviceusage.serviceUsageAdmin),其中包含serviceusage.services.enable權限。瞭解如何授予角色。gcloud services enable container.googleapis.com
cloudbuild.googleapis.com -
將角色授予使用者帳戶。針對下列每個 IAM 角色,執行一次下列指令:
roles/container.admin, roles/iam.serviceAccountAdmin, roles/cloudbuild.builds.editorgcloud projects add-iam-policy-binding PROJECT_ID --member="user:USER_IDENTIFIER" --role=ROLE
更改下列內容:
PROJECT_ID:專案 ID。USER_IDENTIFIER:使用者帳戶的 ID。 例如:myemail@example.com。ROLE:授予使用者帳戶的 IAM 角色。
- 由於本教學課程使用 TPU Trillium (v6e),請選取有供應情形的區域或可用區。詳情請參閱「Cloud TPU 配額」。
準備環境
在本教學課程中,您將使用 Cloud Shell。Cloud Shell 已預先安裝本教學課程所用的 gcloud、helm 和 kubectl 指令列工具。
在 Google Cloud 主控台視窗頂端,按一下「啟用 Cloud Shell」
按鈕。系統會在Google Cloud 控制台的新頁框中開啟 Cloud Shell 工作階段,並顯示指令列提示。
在終端機中,複製
kubernetes-engine-samples存放區:git clone https://github.com/GoogleCloudPlatform/kubernetes-engine-samples.git切換到包含範例檔案的目錄:
cd kubernetes-engine-samples/ai-ml/gke-ray/raytrain/maxtext建立並啟用 Python 虛擬環境:
python3 -m venv ray-env source ray-env/bin/activate安裝 Ray CLI:
pip install "ray[default]==2.55.0"請設定下列環境變數:
export PROJECT_ID=$(gcloud config get project) export PROJECT_NUMBER=$(gcloud projects describe ${PROJECT_ID} --format="value(projectNumber)") export GS_BUCKET=GS_BUCKET export KSA_NAME=KSA_NAME export NAMESPACE=default export CLUSTER_NAME=CLUSTER_NAME export REGION=REGION export ZONE=ZONE export CLUSTER_VERSION=1.35.2-gke.1842000更改下列內容:
GS_BUCKET:Cloud Storage bucket 的名稱。KSA_NAME:Kubernetes 服務帳戶的名稱。CLUSTER_NAME:新叢集的名稱。REGION:TPU Trillium 容量所在的區域。ZONE:TPU Trillium 容量所在的可用區。詳情請參閱「GKE 中的 TPU 可用性」。
設定 Cloud TPU 多配量的叢集網路
在多主機 TPU 配量中,TPU 裝置會透過高速晶片間互連網路通訊。不過,執行 Multislice 工作時,TPU Slice 必須透過 DCN 互相通訊。標準 Kubernetes Pod 網路可能會造成這類流量的瓶頸。ct6e-standard-4t 機型由多個實體網路介面卡 (NIC) 支援。為達到最佳效能,請建立兩個額外的 VPC 網路,並使用 GKE DRANET 將這些網路直接連線至 Ray Pod。
建立兩個額外的虛擬私有雲網路,並設定較大的最大訓練單元 (MTU):
gcloud compute networks create ${CLUSTER_NAME}-net-1 \ --subnet-mode=custom \ --mtu=8896 gcloud compute networks create ${CLUSTER_NAME}-net-2 \ --subnet-mode=custom \ --mtu=8896建立專屬子網路:
gcloud compute networks subnets create tpu-subnet-1 \ --network=${CLUSTER_NAME}-net-1 \ --region=${REGION} \ --range=10.50.0.0/16 gcloud compute networks subnets create tpu-subnet-2 \ --network=${CLUSTER_NAME}-net-2 \ --region=${REGION} \ --range=10.60.0.0/16
建立 GKE 叢集
您可以在 GKE Autopilot 或 Standard 叢集的 TPU 上設定 KubeRay。建議您使用 Autopilot 叢集,享受全代管的 Kubernetes 體驗。如要選擇最適合工作負載的 GKE 作業模式,請參閱「關於 GKE 作業模式」。
如要使用 GKE 代管 DRANET,叢集必須使用 1.35.2-gke.1842000 以上版本 (Autopilot 模式),或 1.34.1-gke.1829001 以上版本 (標準模式)。本教學課程使用 1.35.2-gke.1842000 版。
Autopilot
在 Cloud Shell 中執行下列指令:
gcloud container clusters create-auto $CLUSTER_NAME \ --enable-ray-operator \ --machine-type=n1-standard-16 \ --location=$REGION \ --cluster-version=${CLUSTER_VERSION}如要與叢集通訊,請設定
kubectl:gcloud container clusters get-credentials CLUSTER_NAME \ --location=$REGION
標準
在 Cloud Shell 中執行下列指令,建立啟用 Ray 運算子外掛程式的標準叢集:
gcloud container clusters create $CLUSTER_NAME \ --addons=RayOperator,GcsFuseCsiDriver \ --machine-type=n1-standard-16 \ --enable-dataplane-v2 \ --workload-pool=$PROJECT_ID.svc.id.goog \ --location=$ZONE \ --cluster-version=${CLUSTER_VERSION}這個指令也會啟用
GcsFuseCsiDriver,讓 Pod 將 Cloud Storage bucket 掛接為本機檔案系統。建立叢集可能需要幾分鐘的時間。如要與叢集通訊,請設定
kubectl:gcloud container clusters get-credentials CLUSTER_NAME \ --location=$ZONE建立第一個多主機 TPU 配量節點集區,並啟用 GKE DRANET:
gcloud container node-pools create v6e-16-0 \ --location=$ZONE \ --cluster=$CLUSTER_NAME \ --machine-type=ct6e-standard-4t \ --threads-per-core=1 \ --tpu-topology=4x4 \ --num-nodes=4 \ --additional-node-network=network=${CLUSTER_NAME}-net-1,subnetwork=tpu-subnet-1 \ --additional-node-network=network=${CLUSTER_NAME}-net-2,subnetwork=tpu-subnet-2 \ --node-labels=cloud.google.com/gke-networking-dra-driver=true \ --enable-gvnic \ --scopes=https://www.googleapis.com/auth/cloud-platform建立第二個 TPU 配量節點集區:
gcloud container node-pools create v6e-16-1 \ --location=$ZONE \ --cluster=$CLUSTER_NAME \ --machine-type=ct6e-standard-4t \ --threads-per-core=1 \ --tpu-topology=4x4 \ --num-nodes=4 \ --additional-node-network=network=${CLUSTER_NAME}-net-1,subnetwork=tpu-subnet-1 \ --additional-node-network=network=${CLUSTER_NAME}-net-2,subnetwork=tpu-subnet-2 \ --node-labels=cloud.google.com/gke-networking-dra-driver=true \ --enable-gvnic \ --scopes=https://www.googleapis.com/auth/cloud-platform
GKE 會佈建由四個 TPU Trillium (v6e) VM 組成的節點集區,這些 VM 會一併設定為具有 4x4 拓撲的多主機 TPU 配量。這個節點集區已可處理分散式訓練工作負載。
啟用 Ray 運算子的 GKE 叢集會自動在叢集中安裝 KubeRay 和 KubeRay TPU 網頁掛鉤。
設定 Cloud Storage 值區和服務帳戶
為多主機 TPU 節點之間的共用檢查點建立 Cloud Storage bucket。
gsutil mb -p ${PROJECT_ID} -c STANDARD -l ${REGION} gs://${GS_BUCKET}如要啟用 Cloud Storage bucket 的存取權,請建立 Kubernetes 服務帳戶:
kubectl create serviceaccount ${KSA_NAME} --namespace ${NAMESPACE}如要啟用 Cloud Storage 值區的存取權,請將必要的 IAM 政策繫結新增至服務帳戶:
gcloud storage buckets add-iam-policy-binding gs://${GS_BUCKET} \ --member "principal://iam.googleapis.com/projects/${PROJECT_NUMBER}/locations/global/workloadIdentityPools/${PROJECT_ID}.svc.id.goog/subject/ns/${NAMESPACE}/sa/${KSA_NAME}" \ --role "roles/storage.objectUser"
建立訓練指令碼
maxtext_multi_slice_trainer.py 指令碼使用 Ray Train 的 JaxTrainer,在兩個 TPU 配量中執行分散式 MaxText 訓練工作。這個指令碼會為八個多主機 TPU 工作站設定訓練環境,並在每個 worker 節點上執行 MaxText 訓練工作。train_loop_per_worker 函式會包裝 MaxText 的主要進入點,並使用 Ray 的分散式排程器,在多主機 TPU 切片上執行 MaxText 訓練器:
上述指令碼定義的 JaxTrainer 執行個體會要求八個工作站,以及 4x4 的拓撲。在內部,Ray 會在兩個 TPU 切片之間佈建 SlicePlacementGroup,並確保 Ray Train 工作站會在兩個切片之間以原子方式執行,每個主機有一個工作站。
訓練模型
目前目錄中的
ray-cluster.tpu-multi-slice.yaml資訊清單會定義 RayCluster 自訂資源。這個資訊清單包含 DRANETResourceClaimTemplate,可為 GKE DRANET 和 Multislice 佈建網路裝置:上述 RayCluster 規格會建立一個 TPU 工作站群組,每個副本有八個工作站 (
numOfHosts: 4),並有兩個副本。每個工作站會要求四個 TPU 晶片 (google.com/tpu: "4")。工作站會排定在 TPU Trillium 節點上執行 (tpu-v6e-slice),該節點屬於同一個共置多主機配量。KubeRay 會以不可分割的形式,擴充切片中的所有四個工作站。GKE 會透過異動 Webhook 啟動排程所需的 JAX 環境變數和 Pod 親和性。如要建立 RayCluster,請套用資訊清單:
envsubst < ray-cluster.tpu-multi-slice.yaml | kubectl apply -f -確認叢集已準備就緒並正在執行:
kubectl get rayclusters maxtext-tpu-cluster畫面會顯示如下的輸出內容:
NAME DESIRED WORKERS AVAILABLE WORKERS CPUS MEMORY GPUS STATUS AGE maxtext-tpu-cluster 8 8 72 1579277216Ki 0 ready 2m11s如要透過 Ray 首節點服務存取 Ray 資訊主頁,請建立連接埠轉送工作階段:
kubectl port-forward svc/maxtext-tpu-cluster-head-svc 8265:8265 2>&1 >/dev/null &確認可從本機環境連上 RayCluster:
ray list nodes --address http://localhost:8265畫面會顯示如下的輸出內容:
ray list nodes --address http://localhost:8265 2026-04-21 10:20:20,080 - INFO - Note: NumExpr detected 64 cores but "NUMEXPR_MAX_THREADS" not set, so enforcing safe limit of 8. 2026-04-21 10:20:20,080 - INFO - NumExpr defaulting to 8 threads. ======== List: 2026-04-21 10:20:20.945431 ======== Stats: ------------------------------ Total: 9 Table: ------------------------------ NODE_ID NODE_IP IS_HEAD_NODE STATE STATE_MESSAGE NODE_NAME RESOURCES_TOTAL LABELS 0 4f0e4d742de5375047c7688f4d2bc64a42d1e5c77c2d8344b3b375a1 10.68.9.5 False ALIVE 10.68.9.5 CPU: 8.0 ray.io/accelerator-type: TPU-V6E TPU: 4.0 ray.io/node-group: tpu-group accelerator_type:TPU-V6E: 1.0 ray.io/node-id: 4f0e4d742... memory: 186.265 GiB ray.io/tpu-pod-type: v6e-16 node:10.68.9.5: 1.0 ray.io/tpu-slice-name: tpu-group-0 object_store_memory: 186.265 GiB ray.io/tpu-topology: 4x4 tpu-group-0: 1.0 ray.io/tpu-worker-id: '1' ... 6 ce7056807b95831ce107ba1951dac34b80635e6fdbb312e7f9649938 10.68.2.9 True ALIVE 10.68.2.9 CPU: 8.0 ray.io/node-group: headgroup memory: 16.000 GiB ray.io/node-id: ce7056807... node:10.68.2.9: 1.0 node:__internal_head__: 1.0 object_store_memory: 4.765 GiB ...下載 MaxText 基本設定檔。訓練指令碼需要這個檔案,才能設定模型的預設超參數:
curl -O https://raw.githubusercontent.com/google/maxtext/maxtext-v0.2.1/src/maxtext/configs/base.yml將 JaxTrainer 指令碼提交至 RayCluster,並確認 RayJob 是否順利完成:
Llama 3 8B
ray job submit \
--address http://localhost:8265 \
--working-dir . \
--runtime-env-json '{"excludes": ["ray-env", ".git"]}' \
-- python maxtext_multi_slice_trainer.py \
base.yml \
base_output_directory=/data/ \
dataset_type=synthetic \
per_device_batch_size=4 \
max_target_length=4096 \
model_name=llama3-8b \
steps=100 \
ici_fsdp_parallelism=4 \
ici_tensor_parallelism=4 \
run_name=rayjob-multi-slice
Llama 3 70B
ray job submit \
--address http://localhost:8265 \
--working-dir . \
--runtime-env-json '{"excludes": ["ray-env", ".git"]}' \
-- python maxtext_multi_slice_trainer.py \
base.yml \
base_output_directory=/data/ \
dataset_type=synthetic \
per_device_batch_size=2 \
max_target_length=4096 \
model_name=llama3-70b \
steps=100 \
ici_tensor_parallelism=4 \
ici_fsdp_parallelism=4 \
dcn_fsdp_parallelism=2 \
dcn_data_parallelism=1 \
remat_policy=full \
run_name=rayjob-multi-slice-70b-fsdp
上述指令會提交 Python 指令碼,該指令碼會呼叫 RayCluster 的 JaxTrainer Ray 程式碼。ray job submit 指令包含一些 MaxText 專屬引數,可傳遞至模型設定。
在終端機中,您應該會看到類似下列內容的 Llama 3 70B 工作輸出內容:
[process=5][thread=save_finalize][step=99] CheckpointManager Save Finalize is done on all hosts. [repeated 7x across cluster]
(RayTrainWorker pid=130520, ip=10.60.7.7) [process=5][thread=TrainingThread(train_fn_with_final_checkpoint_flush)][step=99][wait_until_finished] Done waiting for Save Finalize thread (save_finalize) running at step=99. [repeated 7x across cluster]
(RayTrainWorker pid=130520, ip=10.60.7.7) [process=5][thread=TrainingThread(train_fn_with_final_checkpoint_flush)][wait_until_finished] No Save Finalize thread to wait for. Returning. [repeated 6x across cluster]
(RayTrainWorker pid=130520, ip=10.60.7.7) completed step: 99, seconds: 0.693, TFLOP/s/device: 83.171, Tokens/s/device: 11819.175, total_weights: 262144, loss: 0.334 [repeated 6x across cluster]
------------------------------------------
Job 'raysubmit_XwUdZMrhsYRKvjqs' succeeded
------------------------------------------
在 Spot VM 上執行多切片彈性訓練
使用 TPU 等熱門加速器時,運用 Spot VM 可能會大幅降低成本。不過,Spot VM 可能會意外遭到先占。
Ray Train 支援彈性訓練,因此工作可以動態擴大或縮減參與的 TPU 配量數量,不會發生失敗情形。如果某個分片遭到先占,Ray 會暫停訓練迴圈,等待其餘工作者重組,從最新的 MaxText 檢查點還原,並在較小的資源用量下繼續訓練。
如要啟用彈性訓練,請將 ScalingConfig 中的 num_workers 參數從靜態整數變更為代表 (minimum_workers, maximum_workers) 的元組。此外,請在 FailureConfig(max_failures=3) 中新增 RunConfig,指示 Ray Train 最多重試訓練迴圈 3 次,而不是在工作站遭到搶占時完全失敗。
更新 Ray Train 指令碼
當前目錄中的
maxtext_elastic_trainer.py指令碼會啟用彈性訓練。請注意,這會設定num_workers=(4,8),告知 Ray 至少有一個 16 個晶片的切片 (四個工作人員) 可用時繼續執行,但盡可能擴充至兩個切片 (八個工作人員)。包括FailureConfig,可啟用彈性訓練、定義重試次數,並確保工作在搶占期間存留:使用 Ray Job CLI 提交工作。請務必提供不重複的
run_name,以免檢查點與先前的執行作業發生衝突。ray job submit \ --address http://localhost:8265 \ --working-dir . \ --runtime-env-json '{"excludes": ["ray-env", ".git"]}' \ -- python maxtext_elastic_trainer.py \ base.yml \ base_output_directory=/data/ \ dataset_type=synthetic \ per_device_batch_size=4 \ max_target_length=4096 \ model_name=llama3-8b \ steps=100 \ ici_fsdp_parallelism=4 \ ici_tensor_parallelism=4 \ run_name=rayjob-elastic-8b如要在訓練期間模擬節點終止或先占,請刪除 Pod。
kubectl delete pod $(kubectl get pods -l ray.io/node-type=worker -o jsonpath='{.items[0].metadata.name}')
終端機會記錄工作站失敗情形,但協調控制器會讓工作保持運作,並在最低拓撲可用時,從 /data/rayjob-elastic-8b/checkpoints 檢查點自動繼續執行。
因為 MaxText 會在恢復時動態重新計算裝置網格,所以您不必編寫任何自訂邏輯,在拓撲縮小時處理檢查點重新分片。JAX 的 Orbax 檢查點會自動將儲存的權重重新分片到新的實體版面配置中,然後繼續訓練迴圈。以下輸出內容顯示 Ray Train 控制器偵測到叢集中新提供的 TPU 資源,並在訓練期間執行從一個切片 (四個工作站) 到兩個切片 (八個工作站) 的擴充作業。
...
(pid=, ip=10.68.9.5) W0421 04:19:07.570048 20579 grpc_transport.cc:1930] GetMultiSliceTopology returned with status: UNAVAILABLE: failed to connect to all addresses; last error: UNKNOWN: ipv4:10.68.8.5:9915: connect endpoint failed (Failed to connect to remote host: Connection refused)
...
(TrainController pid=23150) Detected changes in the cluster resources. Deciding to resize the worker group from 4 -> 8 workers.
(TrainController pid=23150) Using SlicePlacementGroup utility to reserve 2 slice(s) with topology '4x4'...
(TrainController pid=23150) Attempting to start training worker group of size 8 with the following resources: [{'TPU': 4, 'accelerator_type:TPU-V6E': 0.001}] * 8
清除所用資源
如要避免系統向您的 Google Cloud 帳戶收取本教學課程所用資源的費用,請刪除含有相關資源的專案,或者保留專案但刪除個別資源。
刪除 RayCluster:
kubectl delete raycluster maxtext-tpu-cluster刪除 GKE 叢集:
gcloud container clusters delete $CLUSTER_NAME --zone=$ZONE刪除 Cloud Storage bucket:
gsutil rm -r gs://${GS_BUCKET}
後續步驟
- 瞭解 Kubernetes 上的 Ray。
- 瞭解如何在 GKE 上使用 TPU 提供 vLLM。
- 進一步瞭解 GKE 中的 TPU。