本教學課程說明如何使用 MaxText、Ray Train 和 TPU,在 Google Kubernetes Engine (GKE) 上訓練 Llama 3 8B 大型語言模型 (LLM)。
本教學課程提供完整的端對端逐步操作說明,從設定必要的雲端基礎架構,到提交及順利在多主機 TPU 上執行訓練工作負載,都會詳細說明。
本教學課程適用於平台管理員和營運人員,以及想要瞭解如何在分散式多主機 TPU 節點上訓練大型模型的資料和 AI 專家。
背景
結合 GKE、KubeRay、MaxText 和 TPU,可為大規模模型訓練作業提供強大且可擴充的平台。本節說明本指南使用的主要技術:
JAX
JAX 是 Python 程式庫,專為加速器導向的陣列運算和程式轉換而設計,適用於高效能數值運算和大規模機器學習。
JAX 提供可擴充的系統,用於轉換 jax.grad、jax.jit 和 jax.vmap 等數值函式,並利用 XLA 編譯器建立經過高度最佳化的程式碼,在 GPU 和 TPU 等加速器上有效率地擴充。JAX 的核心功能在於可組合性,使用者可結合這些轉換,建構複雜的高效能數值程式,以供分散式執行。
MaxText
MaxText 是高效能的開放原始碼大型語言模型 (LLM),專為擴充性和自訂性而設計。MaxText 以 JAX 為基礎建構,並經過最佳化,可在 Cloud TPU 和 GPU 上有效率地執行。
TPU
Tensor Processing Unit (TPU) 是 Google 專為機器學習工作負載最佳化而設計的加速器。與一般用途的 CPU 或平行處理 GPU 不同,TPU 專為深度學習基礎的大量矩陣和張量運算而設計,因此能有效執行這項特定工作。TPU 的主要優勢在於大規模效能。
本教學課程使用第六代 TPU「Trillium」。詳情請參閱「使用 TPU Trillium 的優點」。
KubeRay
KubeRay 是 Kubernetes 運算子,可提供統一的方式,在 Kubernetes 上部署、管理及監控 Ray 應用程式。KubeRay 運算子會透過 Ray on GKE 外掛程式安裝及管理,建議您使用這個外掛程式在 GKE 上部署及管理 Ray 叢集。
目標
本教學課程說明如何執行下列操作:
- 設定具有多主機 TPU 節點集區的 GKE 叢集。
- 設定 KubeRay 管理分散式訓練環境。
- 建構包含 MaxText、Ray 和 JAX 依附元件的自訂 Docker 映像檔。
- 建立 Python 訓練指令碼,使用 Ray Train 的
JaxTrainer在 TPU 切片中協調 MaxText 訓練迴圈。 - 定義
RayCluster自訂資源,以佈建具備必要 TPU 資源的頭部和工作節點。 - 將訓練工作提交至
RayCluster,並監控進度。 - 使用 Cloud Storage 儲存模型檢查點。
事前準備
- Sign in to your Google Cloud account. If you're new to Google Cloud, create an account to evaluate how our products perform in real-world scenarios. New customers also get $300 in free credits to run, test, and deploy workloads.
-
Install the Google Cloud CLI.
-
若您採用的是外部識別資訊提供者 (IdP),請先使用聯合身分登入 gcloud CLI。
-
執行下列指令,初始化 gcloud CLI:
gcloud init -
Create or select a Google Cloud project.
Roles required to select or create a project
- Select a project: Selecting a project doesn't require a specific IAM role—you can select any project that you've been granted a role on.
-
Create a project: To create a project, you need the Project Creator role
(
roles/resourcemanager.projectCreator), which contains theresourcemanager.projects.createpermission. Learn how to grant roles.
-
Create a Google Cloud project:
gcloud projects create PROJECT_ID
Replace
PROJECT_IDwith a name for the Google Cloud project you are creating. -
Select the Google Cloud project that you created:
gcloud config set project PROJECT_ID
Replace
PROJECT_IDwith your Google Cloud project name.
-
Verify that billing is enabled for your Google Cloud project.
-
Enable the required API:
Roles required to enable APIs
To enable APIs, you need the Service Usage Admin IAM role (
roles/serviceusage.serviceUsageAdmin), which contains theserviceusage.services.enablepermission. Learn how to grant roles.gcloud services enable container.googleapis.com
-
Install the Google Cloud CLI.
-
若您採用的是外部識別資訊提供者 (IdP),請先使用聯合身分登入 gcloud CLI。
-
執行下列指令,初始化 gcloud CLI:
gcloud init -
Create or select a Google Cloud project.
Roles required to select or create a project
- Select a project: Selecting a project doesn't require a specific IAM role—you can select any project that you've been granted a role on.
-
Create a project: To create a project, you need the Project Creator role
(
roles/resourcemanager.projectCreator), which contains theresourcemanager.projects.createpermission. Learn how to grant roles.
-
Create a Google Cloud project:
gcloud projects create PROJECT_ID
Replace
PROJECT_IDwith a name for the Google Cloud project you are creating. -
Select the Google Cloud project that you created:
gcloud config set project PROJECT_ID
Replace
PROJECT_IDwith your Google Cloud project name.
-
Verify that billing is enabled for your Google Cloud project.
-
Enable the required API:
Roles required to enable APIs
To enable APIs, you need the Service Usage Admin IAM role (
roles/serviceusage.serviceUsageAdmin), which contains theserviceusage.services.enablepermission. Learn how to grant roles.gcloud services enable container.googleapis.com
-
Grant roles to your user account. Run the following command once for each of the following IAM roles:
roles/container.admin, roles/iam.serviceAccountAdmingcloud projects add-iam-policy-binding PROJECT_ID --member="user:USER_IDENTIFIER" --role=ROLE
Replace the following:
PROJECT_ID: Your project ID.USER_IDENTIFIER: The identifier for your user account. For example,myemail@example.com.ROLE: The IAM role that you grant to your user account.
- 由於本教學課程使用 TPU Trillium (v6e),請選取可用的區域或可用區。詳情請參閱「Cloud TPU 配額」。
準備環境
在本教學課程中,您將使用 Cloud Shell。Cloud Shell 已預先安裝本教學課程所用的 gcloud、helm 和 kubectl 指令列工具。
在 Google Cloud 主控台視窗頂端,按一下「啟用 Cloud Shell」
按鈕。系統會在Google Cloud 控制台的新頁框中開啟 Cloud Shell 工作階段,並顯示指令列提示。
建立並啟用 Python 虛擬環境:
python3 -m venv ray-env source ray-env/bin/activate安裝 Ray CLI 和其他依附元件:
pip install "ray[default]==2.49.1"請設定下列環境變數:
export PROJECT_ID=$(gcloud config get project) export PROJECT_NUMBER=$(gcloud projects describe ${PROJECT_ID} --format="value(projectNumber)") export GS_BUCKET=GS_BUCKET export KSA_NAME=KSA_NAME export NAMESPACE=default export CLUSTER_NAME=CLUSTER_NAME export REGION=REGION export ZONE=ZONE export ARTIFACT_REGISTRY=ARTIFACT_REGISTRY更改下列內容:
GS_BUCKET:Cloud Storage bucket 的名稱。KSA_NAME:Kubernetes 服務帳戶的名稱。CLUSTER_NAME:新叢集的名稱。REGION:TPU Trillium 容量可用的區域。ZONE:TPU Trillium 容量所在的可用區。詳情請參閱「GKE 中的 TPU 可用性」。ARTIFACT_REGISTRY:Artifact Registry 存放區的名稱。
建立 GKE 叢集
您可以在 GKE Autopilot 或 Standard 叢集的 TPU 上設定 KubeRay。建議您使用 Autopilot 叢集,享受全代管的 Kubernetes 體驗。如要選擇最適合工作負載的 GKE 作業模式,請參閱「關於 GKE 作業模式」。
Autopilot
在 Cloud Shell 中執行下列指令:
gcloud container clusters create-auto $CLUSTER_NAME \ --enable-ray-operator \ --machine-type=n1-standard-16 \ --location=$REGION如要與叢集通訊,請設定
kubectl:gcloud container clusters get-credentials CLUSTER_NAME \ --location=$ZONE
標準
在 Cloud Shell 中執行下列指令,建立啟用 Ray 運算子外掛程式的標準叢集:
gcloud container clusters create $CLUSTER_NAME \ --addons=RayOperator \ --addons GcsFuseCsiDriver \ --machine-type=n1-standard-16 \ --workload-pool=$PROJECT_ID.svc.id.goog \ --location=$ZONE這個指令也會啟用
GcsFuseCsiDriver,讓 Pod 將 Cloud Storage 值區掛接為本機檔案系統。建立叢集可能需要幾分鐘的時間。如要與叢集通訊,請設定
kubectl:gcloud container clusters get-credentials CLUSTER_NAME \ --location=LOCATION建立多主機 TPU 配量節點集區:
gcloud container node-pools create v6e-16 \ --location=$ZONE \ --cluster=$CLUSTER_NAME \ --machine-type=ct6e-standard-4t \ --threads-per-core=1 \ --tpu-topology=4x4 \ --num-nodes=4
GKE 會佈建由四個 TPU Trillium (v6e) VM 組成的節點集區,這些 VM 會一起設定為多主機 TPU 配量,並採用 4x4 拓撲,可供分散式訓練工作負載使用。
啟用 Ray 運算子的 GKE 叢集會自動在叢集中安裝 KubeRay 和 KubeRay TPU 網頁掛鉤。
設定 Cloud Storage 值區和服務帳戶
為多主機 TPU 節點之間的共用檢查點建立 Cloud Storage bucket。
gsutil mb -p ${PROJECT_ID} -c STANDARD -l ${REGION} gs://${GS_BUCKET}如要啟用 Cloud Storage 值區的存取權,請建立 Kubernetes 服務帳戶:
kubectl create serviceaccount ${KSA_NAME} --namespace ${NAMESPACE}如要啟用 Cloud Storage 值區的存取權,請將必要的 IAM 政策繫結新增至服務帳戶:
gcloud storage buckets add-iam-policy-binding gs://${GS_BUCKET} \ --member "principal://iam.googleapis.com/projects/${PROJECT_NUMBER}/locations/global/workloadIdentityPools/${PROJECT_ID}.svc.id.goog/subject/ns/${NAMESPACE}/sa/${KSA_NAME}" \ --role "roles/storage.objectUser"
建立訓練指令碼
下列指令碼使用 Ray Train 的 JaxTrainer 執行分散式 MaxText 訓練工作。這個指令碼會為多主機 TPU 節點集區設定訓練環境,並在每個工作站節點上執行 MaxText 訓練工作。train_loop_per_worker 函式會包裝 MaxText 主要進入點,並使用 Ray 的分散式排程器,在多主機 TPU 切片上執行 MaxText 訓練器。
將下列 Python 指令碼儲存為
maxtext_ray_trainer.py:如要代管自訂映像檔,請建立 Artifact Registry 存放區:
gcloud artifacts repositories create ${ARTIFACT_REGISTRY} \ --repository-format=docker --location=${REGION} && \ gcloud auth configure-docker ${REGION}-docker.pkg.dev如要建構包含 Ray 和 MaxText 依附元件的訓練映像檔,請建立
Dockerfile:建構、標記 Docker 映像檔,並推送至 Artifact Registry:
export DOCKER_IMAGE=${REGION}-docker.pkg.dev/${PROJECT_ID}/${ARTIFACT_REGISTRY}/ray-maxtext:latest gcloud builds submit --tag ${DOCKER_IMAGE}
訓練模型
將下列範例資訊清單儲存為
maxtext-tpu-cluster.yaml:上述 RayCluster 規格會建立 TPU 工作站群組,每個副本有四個工作站 (
numOfHosts: 4)。每個工作站都會要求四個 TPU 晶片 (google.com/tpu: "4")。工作站會排定在執行 TPU Trillium (tpu-v6e-slice) 的節點上執行,而該節點是同一個共置多主機配量的一部分。KubeRay 會以不可分割的形式調度所有四個工作站,而 GKE 會透過變異 Webhook 啟動所需的 JAX 環境變數,以及用於排程的 Pod 親和性。如要在 YAML 檔案中設定必要值,請使用
envsubst建立 RayCluster:envsubst < maxtext-tpu-cluster.yaml | kubectl apply -f -確認叢集已準備就緒並正在執行:
kubectl get rayclusters maxtext-tpu-cluster畫面會顯示如下的輸出內容:
NAME DESIRED WORKERS AVAILABLE WORKERS CPUS MEMORY GPUS STATUS AGE maxtext-tpu-cluster 4 4 40 798027216Ki 0 ready 11m如要透過 Ray 首節點服務存取 Ray 資訊主頁,請建立連接埠轉送工作階段:
kubectl port-forward svc/maxtext-tpu-cluster-head-svc 8265:8265 2>&1 >/dev/null &確認可從本機環境連上 RayCluster:
ray list nodes --address http://localhost:8265畫面會顯示如下的輸出內容:
======== List: 2025-09-13 03:53:16.988269 ======== Stats: ------------------------------ Total: 5 Table: ------------------------------ NODE_ID NODE_IP IS_HEAD_NODE STATE STATE_MESSAGE NODE_NAME RESOURCES_TOTAL LABELS 0 92c79d04c34b659c1e3044f7642ad3fd47eb16f290785237149fab56 10.84.0.9 (...)將 JaxTrainer 指令碼提交至 RayCluster,並確認 RayJob 是否順利完成:
ray job submit \ --address http://localhost:8265 \ -- python /app/maxtext_ray_trainer.py \ /app/maxtext/src/MaxText/configs/base.yml \ base_output_directory=/data/ \ dataset_type=synthetic \ per_device_batch_size=1 \ max_target_length=4096 \ model_name=llama3-8b \ steps=100 \ ici_fsdp_parallelism=4 \ ici_tensor_parallelism=4 \ run_name=rayjob-8b-4096-tp4-4x4上述指令會提交 Python 指令碼,該指令碼會將 JaxTrainer Ray 程式碼呼叫至 RayCluster。
ray job submit指令包含一些 MaxText 專屬引數,可傳遞至模型設定。終端機應會顯示類似下列內容的輸出:
(RayTrainWorker pid=21663, ip=10.12.3.6) completed step: 99, seconds: 1.100, TFLOP/s/device: 179.739, Tokens/s/device: 3725.218, total_weights: 65536, loss: 0.000 [repeated 3x across cluster] ------------------------------------------ Job 'raysubmit_zCrJcWnuymMQv4C3' succeeded ------------------------------------------
清除所用資源
如要避免系統向您的 Google Cloud 帳戶收取本教學課程所用資源的費用,請刪除含有相關資源的專案,或者保留專案但刪除個別資源。
刪除 RayCluster:
kubectl delete raycluster maxtext-tpu-cluster刪除 GKE 叢集:
gcloud container clusters delete $CLUSTER_NAME --zone=$ZONE刪除 Cloud Storage bucket:
gsutil rm -r gs://${GS_BUCKET}刪除 Artifact Registry 存放區:
gcloud artifacts repositories delete ${ARTIFACT_REGISTRY} --location=${REGION} --quiet
後續步驟
- 瞭解 Kubernetes 上的 Ray。
- 瞭解如何在 GKE 上使用 TPU 提供 vLLM。
- 瞭解如何在 GKE 上使用 TPU 提供 SDXL。
- 進一步瞭解 GKE 中的 TPU。