使用 GKE 上的 Ray Train,在 TPU 上進行多切片和彈性訓練

本教學課程說明如何使用 MaxTextRay Train 和 Multislice Trillium TPU,在 Google Kubernetes Engine (GKE) 上訓練 Llama 3 70B 等大型語言模型 (LLM)。本教學課程提供完整的端對端逐步說明,從設定必要的次要資料中心網路,到提交並成功執行 32 個實體 TPU 晶片的分散式訓練工作負載,都有詳細的介紹。

本教學課程的適用對象為平台管理員、操作員和 AI 專家,他們想瞭解如何克服記憶體和網路問題,在分散式多主機 TPU 節點上訓練 700 億個參數的模型。

背景

結合 GKE、KubeRay、MaxText 和 TPU,可為大規模模型訓練作業提供強大且可擴充的平台。本節說明本指南中使用的主要技術:

JAX

JAX 是 Python 程式庫,可進行以加速器為導向的陣列運算和程式轉換,並運用 XLA 編譯器建立高度最佳化的程式碼,在加速器上有效率地擴展。

MaxText

MaxText 是高效能的開放原始碼 LLM 架構,專為擴充性和可自訂性而設計。MaxText 以 JAX 為基礎建構,並經過最佳化,可在 Cloud TPU 上有效率地執行。

TPU

Tensor Processing Unit (TPU) 是 Google 打造的特製加速器,可將機器學習工作負載調整到最佳狀態。與一般用途的 CPU 或平行處理 GPU 不同,TPU 專為深度學習基礎的大量矩陣和張量運算而設計,因此能有效執行這項特定工作。TPU 的主要優勢是可大規模提升效能。

本教學課程使用第六代 TPU TPU Trillium,採用多配量部署模式。Cloud TPU Multislice 是指兩個以上的 Cloud TPU 配量透過資料中心網路 (DCN) 通訊。Multislice 提供完整堆疊,能以符合成本效益的方式進行大規模訓練,並近線性擴充至數萬個 TPU 晶片。如要進一步瞭解多配量,請參閱 Cloud TPU 多配量總覽

KubeRay

KubeRay 是 Kubernetes 運算子,可提供統一的方式,在 Kubernetes 上部署、管理及監控 Ray 應用程式。KubeRay 運算子會透過 Ray on GKE 外掛程式安裝及管理,建議您使用這個外掛程式在 GKE 上部署及管理 Ray 叢集。

GKE Dynamic Resource Allocation Network (DRANET)

GKE DRANET (動態資源分配網路) 這項功能會將高效能網路裝置動態附加至 Pod,略過標準 Kubernetes 網路,並透過 DCN 啟用高效能。

目標

本教學課程說明如何執行下列操作:

  1. 設定具有兩個多主機 TPU 節點集區的 GKE 叢集。
  2. 設定次要 DCN,用於跨切片 TPU 通訊。
  3. 設定 KubeRay,管理分散式訓練環境。
  4. 使用動態資源分配 (DRA) 部署 RayCluster 自訂資源,以進行網路附件。
  5. 利用 Ray Train 的 JaxTrainer 建立 Python 訓練指令碼,在 TPU 切片中協調 MaxText 訓練迴圈。
  6. 執行基準 Llama 3 8B 訓練工作。
  7. 透過 DCN 上的 2D 分片 (張量平行處理和 FSDP),擴充至 Llama 3 70B。

事前準備

  • 登入 Google Cloud 帳戶。如果您是 Google Cloud新手,歡迎 建立帳戶,親自評估產品在實際工作環境中的成效。新客戶還能獲得價值 $300 美元的免費抵免額,可用於執行、測試及部署工作負載。
  • 安裝 Google Cloud CLI。

  • 若您採用的是外部識別資訊提供者 (IdP),請先使用聯合身分登入 gcloud CLI

  • 執行下列指令,初始化 gcloud CLI:

    gcloud init
  • 建立或選取 Google Cloud 專案

    選取或建立專案所需的角色

    • 選取專案:選取專案時,不需要具備特定 IAM 角色,只要您已獲授角色,即可選取任何專案。
    • 建立專案:如要建立專案,您需要具備專案建立者角色 (roles/resourcemanager.projectCreator),其中包含 resourcemanager.projects.create 權限。瞭解如何授予角色
    • 建立 Google Cloud 專案:

      gcloud projects create PROJECT_ID

      PROJECT_ID 替換為您要建立的 Google Cloud 專案名稱。

    • 選取您建立的 Google Cloud 專案:

      gcloud config set project PROJECT_ID

      PROJECT_ID 替換為 Google Cloud 專案名稱。

  • 確認專案已啟用計費功能 Google Cloud

  • 啟用必要的 API:

    啟用 API 時所需的角色

    如要啟用 API,您需要具備服務使用情形管理員 IAM 角色 (roles/serviceusage.serviceUsageAdmin),其中包含 serviceusage.services.enable 權限。瞭解如何授予角色

    gcloud services enable container.googleapis.com cloudbuild.googleapis.com
  • 安裝 Google Cloud CLI。

  • 若您採用的是外部識別資訊提供者 (IdP),請先使用聯合身分登入 gcloud CLI

  • 執行下列指令,初始化 gcloud CLI:

    gcloud init
  • 建立或選取 Google Cloud 專案

    選取或建立專案所需的角色

    • 選取專案:選取專案時,不需要具備特定 IAM 角色,只要您已獲授角色,即可選取任何專案。
    • 建立專案:如要建立專案,您需要具備專案建立者角色 (roles/resourcemanager.projectCreator),其中包含 resourcemanager.projects.create 權限。瞭解如何授予角色
    • 建立 Google Cloud 專案:

      gcloud projects create PROJECT_ID

      PROJECT_ID 替換為您要建立的 Google Cloud 專案名稱。

    • 選取您建立的 Google Cloud 專案:

      gcloud config set project PROJECT_ID

      PROJECT_ID 替換為 Google Cloud 專案名稱。

  • 確認專案已啟用計費功能 Google Cloud

  • 啟用必要的 API:

    啟用 API 時所需的角色

    如要啟用 API,您需要具備服務使用情形管理員 IAM 角色 (roles/serviceusage.serviceUsageAdmin),其中包含 serviceusage.services.enable 權限。瞭解如何授予角色

    gcloud services enable container.googleapis.com cloudbuild.googleapis.com
  • 將角色授予使用者帳戶。針對下列每個 IAM 角色,執行一次下列指令: roles/container.admin, roles/iam.serviceAccountAdmin, roles/cloudbuild.builds.editor

    gcloud projects add-iam-policy-binding PROJECT_ID --member="user:USER_IDENTIFIER" --role=ROLE

    更改下列內容:

    • PROJECT_ID:專案 ID。
    • USER_IDENTIFIER:使用者帳戶的 ID。 例如:myemail@example.com
    • ROLE:授予使用者帳戶的 IAM 角色。
  • 由於本教學課程使用 TPU Trillium (v6e),請選取有供應情形的區域或可用區。詳情請參閱「Cloud TPU 配額」。

準備環境

在本教學課程中,您將使用 Cloud Shell。Cloud Shell 已預先安裝本教學課程所用的 gcloudhelmkubectl 指令列工具。

  1. 前往Google Cloud 控制台

  2. 在 Google Cloud 主控台視窗頂端,按一下「啟用 Cloud Shell」「Activate Shell」(啟用 Shell) 按鈕按鈕。

    系統會在Google Cloud 控制台的新頁框中開啟 Cloud Shell 工作階段,並顯示指令列提示。

  3. 在終端機中,複製 kubernetes-engine-samples 存放區:

    git clone https://github.com/GoogleCloudPlatform/kubernetes-engine-samples.git
    
  4. 切換到包含範例檔案的目錄:

    cd kubernetes-engine-samples/ai-ml/gke-ray/raytrain/maxtext
    
  5. 建立並啟用 Python 虛擬環境:

    python3 -m venv ray-env
    source ray-env/bin/activate
    
  6. 安裝 Ray CLI:

    pip install "ray[default]==2.55.0"
    
  7. 請設定下列環境變數:

    export PROJECT_ID=$(gcloud config get project)
    export PROJECT_NUMBER=$(gcloud projects describe ${PROJECT_ID} --format="value(projectNumber)")
    export GS_BUCKET=GS_BUCKET
    export KSA_NAME=KSA_NAME
    export NAMESPACE=default
    export CLUSTER_NAME=CLUSTER_NAME
    export REGION=REGION
    export ZONE=ZONE
    export CLUSTER_VERSION=1.35.2-gke.1842000
    

    更改下列內容:

    • GS_BUCKET:Cloud Storage bucket 的名稱。
    • KSA_NAME:Kubernetes 服務帳戶的名稱。
    • CLUSTER_NAME:新叢集的名稱。
    • REGION:TPU Trillium 容量所在的區域。
    • ZONE:TPU Trillium 容量所在的可用區。詳情請參閱「GKE 中的 TPU 可用性」。

設定 Cloud TPU 多配量的叢集網路

在多主機 TPU 配量中,TPU 裝置會透過高速晶片間互連網路通訊。不過,執行 Multislice 工作時,TPU Slice 必須透過 DCN 互相通訊。標準 Kubernetes Pod 網路可能會造成這類流量的瓶頸。ct6e-standard-4t 機型由多個實體網路介面卡 (NIC) 支援。為達到最佳效能,請建立兩個額外的 VPC 網路,並使用 GKE DRANET 將這些網路直接連線至 Ray Pod。

  1. 建立兩個額外的虛擬私有雲網路,並設定較大的最大訓練單元 (MTU):

    gcloud compute networks create ${CLUSTER_NAME}-net-1 \
      --subnet-mode=custom \
      --mtu=8896
    
    gcloud compute networks create ${CLUSTER_NAME}-net-2 \
      --subnet-mode=custom \
      --mtu=8896
    
  2. 建立專屬子網路:

    gcloud compute networks subnets create tpu-subnet-1 \
      --network=${CLUSTER_NAME}-net-1 \
      --region=${REGION} \
      --range=10.50.0.0/16
    
    gcloud compute networks subnets create tpu-subnet-2 \
      --network=${CLUSTER_NAME}-net-2 \
      --region=${REGION} \
      --range=10.60.0.0/16
    

建立 GKE 叢集

您可以在 GKE Autopilot 或 Standard 叢集的 TPU 上設定 KubeRay。建議您使用 Autopilot 叢集,享受全代管的 Kubernetes 體驗。如要選擇最適合工作負載的 GKE 作業模式,請參閱「關於 GKE 作業模式」。

如要使用 GKE 代管 DRANET,叢集必須使用 1.35.2-gke.1842000 以上版本 (Autopilot 模式),或 1.34.1-gke.1829001 以上版本 (標準模式)。本教學課程使用 1.35.2-gke.1842000 版。

Autopilot

  1. 在 Cloud Shell 中執行下列指令:

    gcloud container clusters create-auto $CLUSTER_NAME \
        --enable-ray-operator \
        --machine-type=n1-standard-16 \
        --location=$REGION \
        --cluster-version=${CLUSTER_VERSION}
    
  2. 如要與叢集通訊,請設定 kubectl

    gcloud container clusters get-credentials CLUSTER_NAME \
        --location=$REGION
    

標準

  1. 在 Cloud Shell 中執行下列指令,建立啟用 Ray 運算子外掛程式的標準叢集:

    gcloud container clusters create $CLUSTER_NAME \
        --addons=RayOperator,GcsFuseCsiDriver \
        --machine-type=n1-standard-16 \
        --enable-dataplane-v2 \
        --workload-pool=$PROJECT_ID.svc.id.goog \
        --location=$ZONE \
        --cluster-version=${CLUSTER_VERSION}
    

    這個指令也會啟用 GcsFuseCsiDriver,讓 Pod 將 Cloud Storage bucket 掛接為本機檔案系統。建立叢集可能需要幾分鐘的時間。

  2. 如要與叢集通訊,請設定 kubectl

    gcloud container clusters get-credentials CLUSTER_NAME \
        --location=$ZONE
    
  3. 建立第一個多主機 TPU 配量節點集區,並啟用 GKE DRANET:

    gcloud container node-pools create v6e-16-0 \
        --location=$ZONE \
        --cluster=$CLUSTER_NAME \
        --machine-type=ct6e-standard-4t \
        --threads-per-core=1 \
        --tpu-topology=4x4 \
        --num-nodes=4 \
        --additional-node-network=network=${CLUSTER_NAME}-net-1,subnetwork=tpu-subnet-1 \
        --additional-node-network=network=${CLUSTER_NAME}-net-2,subnetwork=tpu-subnet-2 \
        --node-labels=cloud.google.com/gke-networking-dra-driver=true \
        --enable-gvnic \
        --scopes=https://www.googleapis.com/auth/cloud-platform
    
  4. 建立第二個 TPU 配量節點集區:

    gcloud container node-pools create v6e-16-1 \
        --location=$ZONE \
        --cluster=$CLUSTER_NAME \
        --machine-type=ct6e-standard-4t \
        --threads-per-core=1 \
        --tpu-topology=4x4 \
        --num-nodes=4 \
        --additional-node-network=network=${CLUSTER_NAME}-net-1,subnetwork=tpu-subnet-1 \
        --additional-node-network=network=${CLUSTER_NAME}-net-2,subnetwork=tpu-subnet-2 \
        --node-labels=cloud.google.com/gke-networking-dra-driver=true \
        --enable-gvnic \
        --scopes=https://www.googleapis.com/auth/cloud-platform
    

GKE 會佈建由四個 TPU Trillium (v6e) VM 組成的節點集區,這些 VM 會一併設定為具有 4x4 拓撲的多主機 TPU 配量。這個節點集區已可處理分散式訓練工作負載。

啟用 Ray 運算子的 GKE 叢集會自動在叢集中安裝 KubeRay 和 KubeRay TPU 網頁掛鉤

設定 Cloud Storage 值區和服務帳戶

  1. 為多主機 TPU 節點之間的共用檢查點建立 Cloud Storage bucket。

    gsutil mb -p ${PROJECT_ID} -c STANDARD -l ${REGION} gs://${GS_BUCKET}
    
  2. 如要啟用 Cloud Storage bucket 的存取權,請建立 Kubernetes 服務帳戶:

    kubectl create serviceaccount ${KSA_NAME} --namespace ${NAMESPACE}
    
  3. 如要啟用 Cloud Storage 值區的存取權,請將必要的 IAM 政策繫結新增至服務帳戶:

    gcloud storage buckets add-iam-policy-binding gs://${GS_BUCKET} \
        --member "principal://iam.googleapis.com/projects/${PROJECT_NUMBER}/locations/global/workloadIdentityPools/${PROJECT_ID}.svc.id.goog/subject/ns/${NAMESPACE}/sa/${KSA_NAME}" \
        --role "roles/storage.objectUser"
    

建立訓練指令碼

maxtext_multi_slice_trainer.py 指令碼使用 Ray Train 的 JaxTrainer,在兩個 TPU 配量中執行分散式 MaxText 訓練工作。這個指令碼會為八個多主機 TPU 工作站設定訓練環境,並在每個 worker 節點上執行 MaxText 訓練工作。train_loop_per_worker 函式會包裝 MaxText 的主要進入點,並使用 Ray 的分散式排程器,在多主機 TPU 切片上執行 MaxText 訓練器:

import os
from absl import app
import logging
from typing import Sequence
import ray
from ray.train.v2.api.config import ScalingConfig, RunConfig
from ray.train.v2.jax import JaxTrainer

def train_loop_per_worker(config):
    import maxtext
    from maxtext.trainers.pre_train.train import main as maxtext_main

    argv = config["argv"]
    maxtext_main(argv)

def main(argv: Sequence[str]):
    # Convert the config file path to an absolute path
    argv = list(argv)
    if len(argv) > 1:
        argv[1] = os.path.abspath(argv[1])

    trainer = JaxTrainer(
        train_loop_per_worker=train_loop_per_worker,
        train_loop_config={"argv": argv},
        scaling_config=ScalingConfig(
            use_tpu=True,
            num_workers=8,
            topology="4x4",
            accelerator_type="TPU-V6E",
            resources_per_worker={"TPU": 4},
            placement_strategy="SPREAD",
        ),
        run_config=RunConfig(
            name="maxtext_jaxtrainer",
            worker_runtime_env={
                "uv": {
                    # maxtext requires some additional deps
                    "packages": ["maxtext[tpu]==0.2.1"],
                    "uv_pip_install_options": ["--resolution=lowest"]
                },
            },
        ),
    )
    result = trainer.fit()
    logging.info("Training complete!")
    ray.shutdown()

if __name__ == "__main__":
    app.run(main)

上述指令碼定義的 JaxTrainer 執行個體會要求八個工作站,以及 4x4 的拓撲。在內部,Ray 會在兩個 TPU 切片之間佈建 SlicePlacementGroup,並確保 Ray Train 工作站會在兩個切片之間以原子方式執行,每個主機有一個工作站。

訓練模型

  1. 目前目錄中的 ray-cluster.tpu-multi-slice.yaml 資訊清單會定義 RayCluster 自訂資源。這個資訊清單包含 DRANET ResourceClaimTemplate,可為 GKE DRANET 和 Multislice 佈建網路裝置:

    apiVersion: resource.k8s.io/v1
    kind: ResourceClaimTemplate
    metadata:
      name: two-netdev
    spec:
      spec:
        devices:
          requests:
          - name: req-netdev
            exactly:
              deviceClassName: netdev.google.com
              allocationMode: ExactCount
              count: 2
    ---
    apiVersion: ray.io/v1
    kind: RayCluster
    metadata:
      name: maxtext-tpu-cluster
    spec:
      headGroupSpec:
        rayStartParams: {}
        template:
          metadata:
            annotations:
              gke-gcsfuse/volumes: "true"
              gke-gcsfuse/cpu-limit: "0"
              gke-gcsfuse/memory-limit: "0"
              gke-gcsfuse/ephemeral-storage-limit: "0"
          spec:
            serviceAccountName: ${KSA_NAME}
            containers:
              - name: ray-head
                image: rayproject/ray:nightly-py312-tpu
                imagePullPolicy: Always
                ports:
                - containerPort: 6379
                  name: gcs-server
                - containerPort: 8265
                  name: dashboard
                - containerPort: 10001
                  name: client
                resources:
                  limits:
                    memory: "16Gi"
                  requests:
                    cpu: "8"
                    memory: "16Gi"
                volumeMounts:
                - name: gcs-fuse-csi-ephemeral
                  mountPath: /data
                - name: dshm
                  mountPath: /dev/shm
            volumes:
            - name: dshm
              emptyDir:
                medium: Memory
            - name: gcs-fuse-csi-ephemeral
              csi:
                driver: gcsfuse.csi.storage.gke.io
                volumeAttributes:
                  bucketName: ${GS_BUCKET}
                  mountOptions: "implicit-dirs,uid=1000,gid=1000,dir-mode=775,file-mode=664,file-cache:max-size-mb:-1"
            nodeSelector:
              iam.gke.io/gke-metadata-server-enabled: "true"
      workerGroupSpecs:
        - replicas: 2
          numOfHosts: 4
          groupName: tpu-group
          rayStartParams: 
            metrics-export-port: "8082"
          template:
            metadata:
              annotations:
                gke-gcsfuse/volumes: "true"
                gke-gcsfuse/cpu-limit: "0"
                gke-gcsfuse/memory-limit: "0"
                gke-gcsfuse/ephemeral-storage-limit: "0"
            spec:
              serviceAccountName: ${KSA_NAME}
              resourceClaims:
              - name: netdev
                resourceClaimTemplateName: two-netdev
              containers:
                - name: ray-worker
                  image: rayproject/ray:nightly-py312-tpu
                  imagePullPolicy: Always
                  resources:
                    claims:
                    - name: netdev
                    limits:
                      memory: 200G
                      google.com/tpu: "4"
                    requests:
                      cpu: "8"
                      memory: 200G
                      google.com/tpu: "4"
                  env:
                    - name: MEGASCALE_NUM_SLICES
                      value: "2"
                    - name: MEGASCALE_PORT
                      value: "9915"
                    - name: JAX_PLATFORMS
                      value: tpu,cpu
                    - name: ENABLE_PJRT_COMPATIBILITY
                      value: "true"
                    - name: LIBTPU_INIT_ARGS
                      value: "--xla_tpu_scoped_vmem_limit_kib=122880 --xla_tpu_use_minor_sharding_for_major_trivial_input=true --xla_tpu_relayout_group_size_threshold_for_reduce_scatter=1 --xla_tpu_assign_all_reduce_scatter_layout --xla_tpu_enable_async_collective_fusion_fuse_all_gather=true --xla_tpu_enable_async_collective_fusion_multiple_steps=true --xla_tpu_overlap_compute_collective_tc=true --xla_enable_async_all_gather=true --megascale_grpc_interface_prefixes=eth1,eth2,lo"
                  securityContext:
                    privileged: true
                  volumeMounts:
                  - name: gcs-fuse-csi-ephemeral
                    mountPath: /data
                  - name: dshm
                    mountPath: /dev/shm
              volumes:
              - name: dshm
                emptyDir:
                  medium: Memory
              - name: gcs-fuse-csi-ephemeral
                csi:
                  driver: gcsfuse.csi.storage.gke.io
                  volumeAttributes:
                    bucketName: ${GS_BUCKET}
                    mountOptions: "implicit-dirs,uid=1000,gid=1000,dir-mode=775,file-mode=664,file-cache:max-size-mb:-1"
              nodeSelector:
                iam.gke.io/gke-metadata-server-enabled: "true"
                cloud.google.com/gke-tpu-accelerator: tpu-v6e-slice
                cloud.google.com/gke-tpu-topology: 4x4
    

    上述 RayCluster 規格會建立一個 TPU 工作站群組,每個副本有八個工作站 (numOfHosts: 4),並有兩個副本。每個工作站會要求四個 TPU 晶片 (google.com/tpu: "4")。工作站會排定在 TPU Trillium 節點上執行 (tpu-v6e-slice),該節點屬於同一個共置多主機配量。KubeRay 會以不可分割的形式,擴充切片中的所有四個工作站。GKE 會透過異動 Webhook 啟動排程所需的 JAX 環境變數和 Pod 親和性。

  2. 如要建立 RayCluster,請套用資訊清單:

    envsubst < ray-cluster.tpu-multi-slice.yaml | kubectl apply -f -
    
  3. 確認叢集已準備就緒並正在執行:

    kubectl get rayclusters maxtext-tpu-cluster
    

    畫面會顯示如下的輸出內容:

    NAME                  DESIRED WORKERS   AVAILABLE WORKERS   CPUS   MEMORY         GPUS   STATUS   AGE
    maxtext-tpu-cluster   8                 8                   72     1579277216Ki   0      ready    2m11s
    
  4. 如要透過 Ray 首節點服務存取 Ray 資訊主頁,請建立連接埠轉送工作階段:

    kubectl port-forward svc/maxtext-tpu-cluster-head-svc 8265:8265 2>&1 >/dev/null &
    
  5. 確認可從本機環境連上 RayCluster:

    ray list nodes --address http://localhost:8265
    

    畫面會顯示如下的輸出內容:

    ray list nodes --address http://localhost:8265
    2026-04-21 10:20:20,080 - INFO - Note: NumExpr detected 64 cores but "NUMEXPR_MAX_THREADS" not set, so enforcing safe limit of 8.
    2026-04-21 10:20:20,080 - INFO - NumExpr defaulting to 8 threads.
    
    ======== List: 2026-04-21 10:20:20.945431 ========
    Stats:
    ------------------------------
    Total: 9
    
    Table:
    ------------------------------
        NODE_ID                                                   NODE_IP     IS_HEAD_NODE    STATE    STATE_MESSAGE    NODE_NAME    RESOURCES_TOTAL                   LABELS
    0  4f0e4d742de5375047c7688f4d2bc64a42d1e5c77c2d8344b3b375a1  10.68.9.5   False           ALIVE                     10.68.9.5    CPU: 8.0                          ray.io/accelerator-type: TPU-V6E
                                                                                                                                    TPU: 4.0                          ray.io/node-group: tpu-group
                                                                                                                                    accelerator_type:TPU-V6E: 1.0     ray.io/node-id: 4f0e4d742...
                                                                                                                                    memory: 186.265 GiB               ray.io/tpu-pod-type: v6e-16
                                                                                                                                    node:10.68.9.5: 1.0               ray.io/tpu-slice-name: tpu-group-0
                                                                                                                                    object_store_memory: 186.265 GiB  ray.io/tpu-topology: 4x4
                                                                                                                                    tpu-group-0: 1.0                  ray.io/tpu-worker-id: '1'
    ...
    6  ce7056807b95831ce107ba1951dac34b80635e6fdbb312e7f9649938  10.68.2.9   True            ALIVE                     10.68.2.9    CPU: 8.0                          ray.io/node-group: headgroup
                                                                                                                                    memory: 16.000 GiB                ray.io/node-id: ce7056807...
                                                                                                                                    node:10.68.2.9: 1.0
                                                                                                                                    node:__internal_head__: 1.0
                                                                                                                                    object_store_memory: 4.765 GiB
    ...
    
  6. 下載 MaxText 基本設定檔。訓練指令碼需要這個檔案,才能設定模型的預設超參數:

    curl -O https://raw.githubusercontent.com/google/maxtext/maxtext-v0.2.1/src/maxtext/configs/base.yml
    
  7. 將 JaxTrainer 指令碼提交至 RayCluster,並確認 RayJob 是否順利完成:

Llama 3 8B

ray job submit \
  --address http://localhost:8265 \
  --working-dir . \
  --runtime-env-json '{"excludes": ["ray-env", ".git"]}' \
  -- python maxtext_multi_slice_trainer.py \
      base.yml \
      base_output_directory=/data/ \
      dataset_type=synthetic \
      per_device_batch_size=4 \
      max_target_length=4096 \
      model_name=llama3-8b \
      steps=100 \
      ici_fsdp_parallelism=4 \
      ici_tensor_parallelism=4 \
      run_name=rayjob-multi-slice

Llama 3 70B

ray job submit \
  --address http://localhost:8265 \
  --working-dir . \
  --runtime-env-json '{"excludes": ["ray-env", ".git"]}' \
  -- python maxtext_multi_slice_trainer.py \
      base.yml \
      base_output_directory=/data/ \
      dataset_type=synthetic \
      per_device_batch_size=2 \
      max_target_length=4096 \
      model_name=llama3-70b \
      steps=100 \
      ici_tensor_parallelism=4 \
      ici_fsdp_parallelism=4 \
      dcn_fsdp_parallelism=2 \
      dcn_data_parallelism=1 \
      remat_policy=full \
      run_name=rayjob-multi-slice-70b-fsdp

上述指令會提交 Python 指令碼,該指令碼會呼叫 RayCluster 的 JaxTrainer Ray 程式碼。ray job submit 指令包含一些 MaxText 專屬引數,可傳遞至模型設定。

在終端機中,您應該會看到類似下列內容的 Llama 3 70B 工作輸出內容:

[process=5][thread=save_finalize][step=99] CheckpointManager Save Finalize is done on all hosts. [repeated 7x across cluster]
(RayTrainWorker pid=130520, ip=10.60.7.7) [process=5][thread=TrainingThread(train_fn_with_final_checkpoint_flush)][step=99][wait_until_finished] Done waiting for Save Finalize thread (save_finalize) running at step=99. [repeated 7x across cluster]
(RayTrainWorker pid=130520, ip=10.60.7.7) [process=5][thread=TrainingThread(train_fn_with_final_checkpoint_flush)][wait_until_finished] No Save Finalize thread to wait for. Returning. [repeated 6x across cluster]
(RayTrainWorker pid=130520, ip=10.60.7.7) completed step: 99, seconds: 0.693, TFLOP/s/device: 83.171, Tokens/s/device: 11819.175, total_weights: 262144, loss: 0.334 [repeated 6x across cluster]

------------------------------------------
Job 'raysubmit_XwUdZMrhsYRKvjqs' succeeded
------------------------------------------

在 Spot VM 上執行多切片彈性訓練

使用 TPU 等熱門加速器時,運用 Spot VM 可能會大幅降低成本。不過,Spot VM 可能會意外遭到先占。

Ray Train 支援彈性訓練,因此工作可以動態擴大或縮減參與的 TPU 配量數量,不會發生失敗情形。如果某個分片遭到先占,Ray 會暫停訓練迴圈,等待其餘工作者重組,從最新的 MaxText 檢查點還原,並在較小的資源用量下繼續訓練。

如要啟用彈性訓練,請將 ScalingConfig 中的 num_workers 參數從靜態整數變更為代表 (minimum_workers, maximum_workers) 的元組。此外,請在 FailureConfig(max_failures=3) 中新增 RunConfig,指示 Ray Train 最多重試訓練迴圈 3 次,而不是在工作站遭到搶占時完全失敗。

更新 Ray Train 指令碼

  1. 當前目錄中的 maxtext_elastic_trainer.py 指令碼會啟用彈性訓練。請注意,這會設定 num_workers=(4,8),告知 Ray 至少有一個 16 個晶片的切片 (四個工作人員) 可用時繼續執行,但盡可能擴充至兩個切片 (八個工作人員)。包括 FailureConfig,可啟用彈性訓練、定義重試次數,並確保工作在搶占期間存留:

    import os
    from absl import app
    import logging
    from typing import Sequence
    import ray
    from ray.train.v2.api.config import ScalingConfig, RunConfig, FailureConfig
    from ray.train.v2.jax import JaxTrainer
    
    def train_loop_per_worker(config):
        import maxtext
        from maxtext.trainers.pre_train.train import main as maxtext_main
    
        argv = config["argv"]
        maxtext_main(argv)
    
    def main(argv: Sequence[str]):
        # Convert the config file path to an absolute path
        argv = list(argv)
        if len(argv) > 1:
            argv[1] = os.path.abspath(argv[1])
    
        trainer = JaxTrainer(
            train_loop_per_worker=train_loop_per_worker,
            train_loop_config={"argv": argv},
            scaling_config=ScalingConfig(
                use_tpu=True,
                # This tells Ray to scale the number of workers between 4 and 8 (i.e. 1 to 2 TPU slices).
                num_workers=(4,8),
                topology="4x4",
                accelerator_type="TPU-V6E",
                resources_per_worker={"TPU": 4},
                placement_strategy="SPREAD",
            ),
            run_config=RunConfig(
                name="maxtext_jaxtrainer",
                # Define a FailureConfig to enable fault tolerance by automatically restarting failed workers.
                failure_config=FailureConfig(max_failures=3),
                worker_runtime_env={
                    "uv": {
                        # maxtext requires some additional deps
                        "packages": ["maxtext[tpu]==0.2.1"],
                        "uv_pip_install_options": ["--resolution=lowest"]
                    },
                },
            ),
        )
        result = trainer.fit()
        logging.info("Training complete!")
        ray.shutdown()
    
    if __name__ == "__main__":
        app.run(main)
    
  2. 使用 Ray Job CLI 提交工作。請務必提供不重複的 run_name,以免檢查點與先前的執行作業發生衝突。

    ray job submit \
      --address http://localhost:8265 \
      --working-dir . \
      --runtime-env-json '{"excludes": ["ray-env", ".git"]}' \
      -- python maxtext_elastic_trainer.py \
          base.yml \
          base_output_directory=/data/ \
          dataset_type=synthetic \
          per_device_batch_size=4 \
          max_target_length=4096 \
          model_name=llama3-8b \
          steps=100 \
          ici_fsdp_parallelism=4 \
          ici_tensor_parallelism=4 \
          run_name=rayjob-elastic-8b
    
  3. 如要在訓練期間模擬節點終止或先占,請刪除 Pod。

    kubectl delete pod $(kubectl get pods -l ray.io/node-type=worker -o jsonpath='{.items[0].metadata.name}')
    

終端機會記錄工作站失敗情形,但協調控制器會讓工作保持運作,並在最低拓撲可用時,從 /data/rayjob-elastic-8b/checkpoints 檢查點自動繼續執行。

因為 MaxText 會在恢復時動態重新計算裝置網格,所以您不必編寫任何自訂邏輯,在拓撲縮小時處理檢查點重新分片。JAX 的 Orbax 檢查點會自動將儲存的權重重新分片到新的實體版面配置中,然後繼續訓練迴圈。以下輸出內容顯示 Ray Train 控制器偵測到叢集中新提供的 TPU 資源,並在訓練期間執行從一個切片 (四個工作站) 到兩個切片 (八個工作站) 的擴充作業。

...
(pid=, ip=10.68.9.5) W0421 04:19:07.570048   20579 grpc_transport.cc:1930] GetMultiSliceTopology returned with status: UNAVAILABLE: failed to connect to all addresses; last error: UNKNOWN: ipv4:10.68.8.5:9915: connect endpoint failed (Failed to connect to remote host: Connection refused)
...
(TrainController pid=23150) Detected changes in the cluster resources. Deciding to resize the worker group from 4 -> 8 workers.
(TrainController pid=23150) Using SlicePlacementGroup utility to reserve 2 slice(s) with topology '4x4'...
(TrainController pid=23150) Attempting to start training worker group of size 8 with the following resources: [{'TPU': 4, 'accelerator_type:TPU-V6E': 0.001}] * 8

清除所用資源

如要避免系統向您的 Google Cloud 帳戶收取本教學課程所用資源的費用,請刪除含有相關資源的專案,或者保留專案但刪除個別資源。

  1. 刪除 RayCluster:

    kubectl delete raycluster maxtext-tpu-cluster
    
  2. 刪除 GKE 叢集:

    gcloud container clusters delete $CLUSTER_NAME --zone=$ZONE
    
  3. 刪除 Cloud Storage bucket:

    gsutil rm -r gs://${GS_BUCKET}
    

後續步驟