在 GKE 上使用 JAX、Ray Train 和 TPU Trillium 訓練 LLM

本教學課程說明如何使用 MaxTextRay Train 和 TPU,在 Google Kubernetes Engine (GKE) 上訓練 Llama 3 8B 大型語言模型 (LLM)。

本教學課程提供完整的端對端逐步操作說明,從設定必要的雲端基礎架構,到提交及順利在多主機 TPU 上執行訓練工作負載,都會詳細說明。

本教學課程適用於平台管理員和營運人員,以及想要瞭解如何在分散式多主機 TPU 節點上訓練大型模型的資料和 AI 專家。

背景

結合 GKE、KubeRay、MaxText 和 TPU,可為大規模模型訓練作業提供強大且可擴充的平台。本節說明本指南使用的主要技術:

JAX

JAX 是 Python 程式庫,專為加速器導向的陣列運算和程式轉換而設計,適用於高效能數值運算和大規模機器學習。

JAX 提供可擴充的系統,用於轉換 jax.gradjax.jitjax.vmap 等數值函式,並利用 XLA 編譯器建立經過高度最佳化的程式碼,在 GPU 和 TPU 等加速器上有效率地擴充。JAX 的核心功能在於可組合性,使用者可結合這些轉換,建構複雜的高效能數值程式,以供分散式執行。

MaxText

MaxText 是高效能的開放原始碼大型語言模型 (LLM),專為擴充性和自訂性而設計。MaxText 以 JAX 為基礎建構,並經過最佳化,可在 Cloud TPU 和 GPU 上有效率地執行。

TPU

Tensor Processing Unit (TPU) 是 Google 專為機器學習工作負載最佳化而設計的加速器。與一般用途的 CPU 或平行處理 GPU 不同,TPU 專為深度學習基礎的大量矩陣和張量運算而設計,因此能有效執行這項特定工作。TPU 的主要優勢在於大規模效能。

本教學課程使用第六代 TPU「Trillium」。詳情請參閱「使用 TPU Trillium 的優點」。

KubeRay

KubeRay 是 Kubernetes 運算子,可提供統一的方式,在 Kubernetes 上部署、管理及監控 Ray 應用程式。KubeRay 運算子會透過 Ray on GKE 外掛程式安裝及管理,建議您使用這個外掛程式在 GKE 上部署及管理 Ray 叢集。

目標

本教學課程說明如何執行下列操作:

  1. 設定具有多主機 TPU 節點集區的 GKE 叢集。
  2. 設定 KubeRay 管理分散式訓練環境。
  3. 建構包含 MaxText、Ray 和 JAX 依附元件的自訂 Docker 映像檔。
  4. 建立 Python 訓練指令碼,使用 Ray Train 的 JaxTrainer 在 TPU 切片中協調 MaxText 訓練迴圈。
  5. 定義 RayCluster 自訂資源,以佈建具備必要 TPU 資源的頭部和工作節點。
  6. 將訓練工作提交至 RayCluster,並監控進度。
  7. 使用 Cloud Storage 儲存模型檢查點。

事前準備

  • Sign in to your Google Cloud account. If you're new to Google Cloud, create an account to evaluate how our products perform in real-world scenarios. New customers also get $300 in free credits to run, test, and deploy workloads.
  • Install the Google Cloud CLI.

  • 若您採用的是外部識別資訊提供者 (IdP),請先使用聯合身分登入 gcloud CLI

  • 執行下列指令,初始化 gcloud CLI:

    gcloud init
  • Create or select a Google Cloud project.

    Roles required to select or create a project

    • Select a project: Selecting a project doesn't require a specific IAM role—you can select any project that you've been granted a role on.
    • Create a project: To create a project, you need the Project Creator role (roles/resourcemanager.projectCreator), which contains the resourcemanager.projects.create permission. Learn how to grant roles.
    • Create a Google Cloud project:

      gcloud projects create PROJECT_ID

      Replace PROJECT_ID with a name for the Google Cloud project you are creating.

    • Select the Google Cloud project that you created:

      gcloud config set project PROJECT_ID

      Replace PROJECT_ID with your Google Cloud project name.

  • Verify that billing is enabled for your Google Cloud project.

  • Enable the required API:

    Roles required to enable APIs

    To enable APIs, you need the Service Usage Admin IAM role (roles/serviceusage.serviceUsageAdmin), which contains the serviceusage.services.enable permission. Learn how to grant roles.

    gcloud services enable container.googleapis.com
  • Install the Google Cloud CLI.

  • 若您採用的是外部識別資訊提供者 (IdP),請先使用聯合身分登入 gcloud CLI

  • 執行下列指令,初始化 gcloud CLI:

    gcloud init
  • Create or select a Google Cloud project.

    Roles required to select or create a project

    • Select a project: Selecting a project doesn't require a specific IAM role—you can select any project that you've been granted a role on.
    • Create a project: To create a project, you need the Project Creator role (roles/resourcemanager.projectCreator), which contains the resourcemanager.projects.create permission. Learn how to grant roles.
    • Create a Google Cloud project:

      gcloud projects create PROJECT_ID

      Replace PROJECT_ID with a name for the Google Cloud project you are creating.

    • Select the Google Cloud project that you created:

      gcloud config set project PROJECT_ID

      Replace PROJECT_ID with your Google Cloud project name.

  • Verify that billing is enabled for your Google Cloud project.

  • Enable the required API:

    Roles required to enable APIs

    To enable APIs, you need the Service Usage Admin IAM role (roles/serviceusage.serviceUsageAdmin), which contains the serviceusage.services.enable permission. Learn how to grant roles.

    gcloud services enable container.googleapis.com
  • Grant roles to your user account. Run the following command once for each of the following IAM roles: roles/container.admin, roles/iam.serviceAccountAdmin

    gcloud projects add-iam-policy-binding PROJECT_ID --member="user:USER_IDENTIFIER" --role=ROLE

    Replace the following:

    • PROJECT_ID: Your project ID.
    • USER_IDENTIFIER: The identifier for your user account. For example, myemail@example.com.
    • ROLE: The IAM role that you grant to your user account.
  • 由於本教學課程使用 TPU Trillium (v6e),請選取可用的區域或可用區。詳情請參閱「Cloud TPU 配額」。

準備環境

在本教學課程中,您將使用 Cloud Shell。Cloud Shell 已預先安裝本教學課程所用的 gcloudhelmkubectl 指令列工具。

  1. 前往Google Cloud 控制台

  2. 在 Google Cloud 主控台視窗頂端,按一下「啟用 Cloud Shell」「Activate Shell」(啟用 Shell) 按鈕按鈕。

    系統會在Google Cloud 控制台的新頁框中開啟 Cloud Shell 工作階段,並顯示指令列提示。

  3. 建立並啟用 Python 虛擬環境:

    python3 -m venv ray-env
    source ray-env/bin/activate
    
  4. 安裝 Ray CLI 和其他依附元件:

    pip install "ray[default]==2.49.1"
    
  5. 請設定下列環境變數:

    export PROJECT_ID=$(gcloud config get project)
    export PROJECT_NUMBER=$(gcloud projects describe ${PROJECT_ID} --format="value(projectNumber)")
    export GS_BUCKET=GS_BUCKET
    export KSA_NAME=KSA_NAME
    export NAMESPACE=default
    export CLUSTER_NAME=CLUSTER_NAME
    export REGION=REGION
    export ZONE=ZONE
    export ARTIFACT_REGISTRY=ARTIFACT_REGISTRY
    

    更改下列內容:

    • GS_BUCKET:Cloud Storage bucket 的名稱。
    • KSA_NAME:Kubernetes 服務帳戶的名稱。
    • CLUSTER_NAME:新叢集的名稱。
    • REGION:TPU Trillium 容量可用的區域。
    • ZONE:TPU Trillium 容量所在的可用區。詳情請參閱「GKE 中的 TPU 可用性」。
    • ARTIFACT_REGISTRY:Artifact Registry 存放區的名稱。

建立 GKE 叢集

您可以在 GKE Autopilot 或 Standard 叢集的 TPU 上設定 KubeRay。建議您使用 Autopilot 叢集,享受全代管的 Kubernetes 體驗。如要選擇最適合工作負載的 GKE 作業模式,請參閱「關於 GKE 作業模式」。

Autopilot

  1. 在 Cloud Shell 中執行下列指令:

    gcloud container clusters create-auto $CLUSTER_NAME \
        --enable-ray-operator \
        --machine-type=n1-standard-16 \
        --location=$REGION
    
  2. 如要與叢集通訊,請設定 kubectl

    gcloud container clusters get-credentials CLUSTER_NAME \
        --location=$ZONE
    

標準

  1. 在 Cloud Shell 中執行下列指令,建立啟用 Ray 運算子外掛程式的標準叢集:

    gcloud container clusters create $CLUSTER_NAME \
        --addons=RayOperator \
        --addons GcsFuseCsiDriver \
        --machine-type=n1-standard-16 \
        --workload-pool=$PROJECT_ID.svc.id.goog \
        --location=$ZONE
    

    這個指令也會啟用 GcsFuseCsiDriver,讓 Pod 將 Cloud Storage 值區掛接為本機檔案系統。建立叢集可能需要幾分鐘的時間。

  2. 如要與叢集通訊,請設定 kubectl

    gcloud container clusters get-credentials CLUSTER_NAME \
        --location=LOCATION
    
  3. 建立多主機 TPU 配量節點集區:

    gcloud container node-pools create v6e-16 \
        --location=$ZONE \
        --cluster=$CLUSTER_NAME \
        --machine-type=ct6e-standard-4t \
        --threads-per-core=1 \
        --tpu-topology=4x4 \
        --num-nodes=4
    

GKE 會佈建由四個 TPU Trillium (v6e) VM 組成的節點集區,這些 VM 會一起設定為多主機 TPU 配量,並採用 4x4 拓撲,可供分散式訓練工作負載使用。

啟用 Ray 運算子的 GKE 叢集會自動在叢集中安裝 KubeRay 和 KubeRay TPU 網頁掛鉤

設定 Cloud Storage 值區和服務帳戶

  1. 為多主機 TPU 節點之間的共用檢查點建立 Cloud Storage bucket。

    gsutil mb -p ${PROJECT_ID} -c STANDARD -l ${REGION} gs://${GS_BUCKET}
    
  2. 如要啟用 Cloud Storage 值區的存取權,請建立 Kubernetes 服務帳戶:

    kubectl create serviceaccount ${KSA_NAME} --namespace ${NAMESPACE}
    
  3. 如要啟用 Cloud Storage 值區的存取權,請將必要的 IAM 政策繫結新增至服務帳戶:

    gcloud storage buckets add-iam-policy-binding gs://${GS_BUCKET} \
        --member "principal://iam.googleapis.com/projects/${PROJECT_NUMBER}/locations/global/workloadIdentityPools/${PROJECT_ID}.svc.id.goog/subject/ns/${NAMESPACE}/sa/${KSA_NAME}" \
        --role "roles/storage.objectUser"
    

建立訓練指令碼

下列指令碼使用 Ray Train 的 JaxTrainer 執行分散式 MaxText 訓練工作。這個指令碼會為多主機 TPU 節點集區設定訓練環境,並在每個工作站節點上執行 MaxText 訓練工作。train_loop_per_worker 函式會包裝 MaxText 主要進入點,並使用 Ray 的分散式排程器,在多主機 TPU 切片上執行 MaxText 訓練器。

  1. 將下列 Python 指令碼儲存為 maxtext_ray_trainer.py

    import os
    from absl import app
    import logging
    from typing import Sequence
    import ray
    from ray.train.v2.api.config import ScalingConfig, RunConfig
    from ray.train.v2.jax import JaxTrainer
    
    def train_loop_per_worker(config):
        from MaxText.train import main as maxtext_main
    
        argv = config["argv"]
        maxtext_main(argv)
    
    def main(argv: Sequence[str]):
        trainer = JaxTrainer(
            train_loop_per_worker=train_loop_per_worker,
            train_loop_config={"argv": argv},
            scaling_config=ScalingConfig(
                use_tpu=True,
                num_workers=4,
                topology="4x4",
                accelerator_type="TPU-V6E",
                resources_per_worker={"TPU": 4},
                placement_strategy="SPREAD",
            ),
            run_config=RunConfig(
                name="maxtext_jaxtrainer",
                worker_runtime_env={
                    "env_vars": {
                        "JAX_PLATFORMS": "tpu",
                        "ENABLE_PJRT_COMPATIBILITY": "true",
                        "TPU_SLICE_BUILDER_DUMP_CHIP_FORCE": "true",
                        "TPU_SLICE_BUILDER_DUMP_ICI": "true",
                        "XLA_FLAGS": "--xla_dump_to=/tmp/xla_dump_file --xla_dump_hlo_as_proto",
                    }
                },
            ),
        )
        result = trainer.fit()
        logging.info("Training complete!")
        ray.shutdown()
    
    if __name__ == "__main__":
        app.run(main)
  2. 如要代管自訂映像檔,請建立 Artifact Registry 存放區:

    gcloud artifacts repositories create ${ARTIFACT_REGISTRY} \
        --repository-format=docker --location=${REGION} && \
    gcloud auth configure-docker ${REGION}-docker.pkg.dev
    
  3. 如要建構包含 Ray 和 MaxText 依附元件的訓練映像檔,請建立 Dockerfile

    # Start from a Ray base image which includes JaxTrainer API.
    # Maxtext with TPU requires Python 3.12.
    FROM rayproject/ray:2.49.1-py312
    
    USER root
    RUN groupadd -r ray 2>/dev/null || true && usermod -g ray ray
    
    RUN sudo apt-get update -y \
      && sudo apt-get install --no-install-recommends -y git \
      && sudo rm -rf /var/lib/apt/lists/*
    
    WORKDIR /app
    
    # Clone the Maxtext repo and build from source, installing TPU dependencies.
    RUN git clone https://github.com/AI-Hypercomputer/maxtext.git
    
    RUN pip install --no-cache-dir uv
    
    RUN cd maxtext && \
        uv pip install --no-cache --system -e .[tpu] --resolution=lowest && \
        install_maxtext_github_deps
    
    # Copy the Ray Maxtext trainer to run on the remote container.
    COPY maxtext_ray_trainer.py .
    
    RUN chown -R ray:ray .
    ENV PYTHONPATH=/app/maxtext/src:/app/maxtext:/app
    USER ray
  4. 建構、標記 Docker 映像檔,並推送至 Artifact Registry:

    export DOCKER_IMAGE=${REGION}-docker.pkg.dev/${PROJECT_ID}/${ARTIFACT_REGISTRY}/ray-maxtext:latest
    gcloud builds submit --tag ${DOCKER_IMAGE}
    

訓練模型

  1. 將下列範例資訊清單儲存為 maxtext-tpu-cluster.yaml

    apiVersion: ray.io/v1
    kind: RayCluster
    metadata:
      name: maxtext-tpu-cluster
    spec:
      headGroupSpec:
        rayStartParams: {}
        template:
          metadata:
            annotations:
              gke-gcsfuse/volumes: "true"
              gke-gcsfuse/cpu-limit: "0"
              gke-gcsfuse/memory-limit: "0"
              gke-gcsfuse/ephemeral-storage-limit: "0"
          spec:
            serviceAccountName: ${KSA_NAME}
            containers:
              - name: ray-head
                image: ${DOCKER_IMAGE}
                imagePullPolicy: IfNotPresent
                ports:
                - containerPort: 6379
                  name: gcs-server
                - containerPort: 8265
                  name: dashboard
                - containerPort: 10001
                  name: client
                resources:
                  limits:
                    memory: "16Gi"
                  requests:
                    cpu: "8"
                    memory: "16Gi"
                volumeMounts:
                - name: gcs-fuse-csi-ephemeral
                  mountPath: /data
                - name: dshm
                  mountPath: /dev/shm
            volumes:
            - name: gcs-fuse-cache
              emptyDir:
                medium: Memory
            - name: dshm
              emptyDir:
                medium: Memory
            - name: gcs-fuse-csi-ephemeral
              csi:
                driver: gcsfuse.csi.storage.gke.io
                volumeAttributes:
                  bucketName: ${GS_BUCKET}
                  mountOptions: "implicit-dirs"
      workerGroupSpecs:
        - replicas: 1
          numOfHosts: 4
          groupName: tpu-group
          rayStartParams: {}
          template:
            metadata:
              annotations:
                gke-gcsfuse/volumes: "true"
                gke-gcsfuse/cpu-limit: "0"
                gke-gcsfuse/memory-limit: "0"
                gke-gcsfuse/ephemeral-storage-limit: "0"
            spec:
              serviceAccountName: ${KSA_NAME}
              containers:
                - name: ray-worker
                  image: ${DOCKER_IMAGE}
                  imagePullPolicy: IfNotPresent
                  resources:
                    limits:
                      memory: 200G
                      google.com/tpu: "4"
                    requests:
                      cpu: "8"
                      memory: 200G
                      google.com/tpu: "4"
                  env:
                    - name: JAX_PLATFORMS
                      value: tpu
                    - name: ENABLE_PJRT_COMPATIBILITY
                      value: "true"
                  volumeMounts:
                  - name: gcs-fuse-csi-ephemeral
                    mountPath: /data
                  - name: dshm
                    mountPath: /dev/shm
              volumes:
              - name: gcs-fuse-cache
                emptyDir:
                  medium: Memory
              - name: dshm
                emptyDir:
                  medium: Memory
              - name: gcs-fuse-csi-ephemeral
                csi:
                  driver: gcsfuse.csi.storage.gke.io
                  volumeAttributes:
                    bucketName: ${GS_BUCKET}
                    mountOptions: "implicit-dirs"
              nodeSelector:
                cloud.google.com/gke-tpu-accelerator: tpu-v6e-slice
                cloud.google.com/gke-tpu-topology: 4x4

    上述 RayCluster 規格會建立 TPU 工作站群組,每個副本有四個工作站 (numOfHosts: 4)。每個工作站都會要求四個 TPU 晶片 (google.com/tpu: "4")。工作站會排定在執行 TPU Trillium (tpu-v6e-slice) 的節點上執行,而該節點是同一個共置多主機配量的一部分。KubeRay 會以不可分割的形式調度所有四個工作站,而 GKE 會透過變異 Webhook 啟動所需的 JAX 環境變數,以及用於排程的 Pod 親和性。

  2. 如要在 YAML 檔案中設定必要值,請使用 envsubst 建立 RayCluster:

    envsubst < maxtext-tpu-cluster.yaml | kubectl apply -f -
    
  3. 確認叢集已準備就緒並正在執行:

    kubectl get rayclusters maxtext-tpu-cluster
    

    畫面會顯示如下的輸出內容:

    NAME                  DESIRED WORKERS   AVAILABLE WORKERS   CPUS   MEMORY        GPUS   STATUS   AGE
    maxtext-tpu-cluster   4                 4                   40     798027216Ki   0      ready    11m
    
  4. 如要透過 Ray 首節點服務存取 Ray 資訊主頁,請建立連接埠轉送工作階段:

    kubectl port-forward svc/maxtext-tpu-cluster-head-svc 8265:8265 2>&1 >/dev/null &
    
  5. 確認可從本機環境連上 RayCluster:

    ray list nodes --address http://localhost:8265
    

    畫面會顯示如下的輸出內容:

    ======== List: 2025-09-13 03:53:16.988269 ========
    Stats:
    ------------------------------
    Total: 5
    Table:
    ------------------------------
        NODE_ID                                                   NODE_IP    IS_HEAD_NODE    STATE    STATE_MESSAGE    NODE_NAME    RESOURCES_TOTAL                  LABELS
    0  92c79d04c34b659c1e3044f7642ad3fd47eb16f290785237149fab56  10.84.0.9
    (...)
    
  6. 將 JaxTrainer 指令碼提交至 RayCluster,並確認 RayJob 是否順利完成:

    ray job submit \
      --address http://localhost:8265 \
      -- python /app/maxtext_ray_trainer.py \
          /app/maxtext/src/MaxText/configs/base.yml \
           base_output_directory=/data/ \
          dataset_type=synthetic \
          per_device_batch_size=1 \
          max_target_length=4096 \
          model_name=llama3-8b \
          steps=100 \
          ici_fsdp_parallelism=4 \
          ici_tensor_parallelism=4 \
          run_name=rayjob-8b-4096-tp4-4x4
    

    上述指令會提交 Python 指令碼,該指令碼會將 JaxTrainer Ray 程式碼呼叫至 RayCluster。ray job submit 指令包含一些 MaxText 專屬引數,可傳遞至模型設定。

    終端機應會顯示類似下列內容的輸出:

    (RayTrainWorker pid=21663, ip=10.12.3.6) completed step: 99, seconds: 1.100, TFLOP/s/device: 179.739, Tokens/s/device: 3725.218, total_weights: 65536, loss: 0.000 [repeated 3x across cluster]
    
    ------------------------------------------
    Job 'raysubmit_zCrJcWnuymMQv4C3' succeeded
    ------------------------------------------
    

清除所用資源

如要避免系統向您的 Google Cloud 帳戶收取本教學課程所用資源的費用,請刪除含有相關資源的專案,或者保留專案但刪除個別資源。

  1. 刪除 RayCluster:

    kubectl delete raycluster maxtext-tpu-cluster
    
  2. 刪除 GKE 叢集:

    gcloud container clusters delete $CLUSTER_NAME --zone=$ZONE
    
  3. 刪除 Cloud Storage bucket:

    gsutil rm -r gs://${GS_BUCKET}
    
  4. 刪除 Artifact Registry 存放區:

    gcloud artifacts repositories delete ${ARTIFACT_REGISTRY} --location=${REGION} --quiet
    

後續步驟