Pelatihan multislice dan elastis di TPU menggunakan Ray Train di GKE

Tutorial ini menunjukkan cara melatih model bahasa besar (LLM) seperti Llama 3 70B di Google Kubernetes Engine (GKE) menggunakan MaxText, Ray Train, dan Multislice Trillium TPU. Tutorial ini memberikan panduan lengkap end-to-end, mulai dari mengonfigurasi jaringan pusat data sekunder yang diperlukan hingga mengirimkan dan berhasil menjalankan workload pelatihan terdistribusi di 32 chip TPU fisik.

Tutorial ini ditujukan bagi admin Platform, operator, dan spesialis AI yang ingin mempelajari cara mengatasi tantangan memori dan jaringan dalam melatih model dengan 70 miliar parameter pada slice TPU multi-host terdistribusi.

Latar belakang

Kombinasi GKE, KubeRay, MaxText, dan TPU menyediakan platform yang canggih dan skalabel untuk pelatihan model berskala besar. Bagian ini menjelaskan teknologi utama yang digunakan dalam panduan ini:

JAX

JAX adalah library Python untuk komputasi array dan transformasi program yang berorientasi pada akselerator, yang memanfaatkan compiler XLA untuk membuat kode yang sangat dioptimalkan dan diskalakan secara efisien pada akselerator.

MaxText

MaxText adalah framework LLM open source berperforma tinggi yang dirancang untuk skalabilitas dan kemampuan penyesuaian. MaxText dibangun di atas JAX dan dioptimalkan untuk berjalan secara efisien di Cloud TPU.

TPU

Tensor Processing Unit (TPU) adalah akselerator yang dirancang khusus oleh Google untuk mengoptimalkan workload machine learning. Tidak seperti CPU serbaguna atau GPU pemrosesan paralel, TPU sangat terspesialisasi untuk komputasi matriks dan tensor masif yang menjadi dasar deep learning, sehingga TPU efisien dalam tugas khusus ini. Keuntungan utama TPU adalah performa dalam skala besar.

Tutorial ini menggunakan TPU Trillium, TPU generasi keenam, dalam pola deployment Multislice. Cloud TPU Multislice adalah tempat dua atau lebih slice Cloud TPU berkomunikasi melalui jaringan pusat data (DCN). Multislice memungkinkan pelatihan skala besar, hemat biaya, dan full-stack dengan penskalaan near-linear hingga puluhan ribu chip TPU. Untuk mengetahui informasi selengkapnya tentang Multislice, lihat Ringkasan Cloud TPU Multislice.

KubeRay

KubeRay adalah operator Kubernetes yang menyediakan cara terpadu untuk men-deploy, mengelola, dan memantau aplikasi Ray di Kubernetes. Operator KubeRay diinstal dan dikelola melalui add-on Ray di GKE, yang merupakan cara yang direkomendasikan untuk men-deploy dan mengelola cluster Ray di GKE.

Jaringan Alokasi Resource Dinamis (DRANET) GKE

GKE DRANET (Dynamic Resource Allocation Network) adalah fitur yang secara dinamis melampirkan perangkat jaringan berperforma tinggi ke Pod, melewati jaringan Kubernetes standar, dan memungkinkan performa tinggi melalui DCN.

Tujuan

Tutorial ini menunjukkan kepada Anda cara melakukan hal berikut:

  1. Siapkan cluster GKE dengan dua node pool TPU multi-host.
  2. Konfigurasi DCN sekunder untuk komunikasi TPU lintas slice.
  3. Konfigurasi KubeRay untuk mengelola lingkungan pelatihan terdistribusi.
  4. Deploy resource kustom RayCluster menggunakan Alokasi Resource Dinamis (DRA) untuk lampiran jaringan.
  5. Buat skrip pelatihan Python dengan memanfaatkan JaxTrainer Ray Train untuk mengatur loop pelatihan MaxText di seluruh slice TPU.
  6. Jalankan tugas pelatihan dasar Llama 3 8B.
  7. Menskalakan hingga Llama 3 70B menggunakan sharding 2D (Tensor Parallelism dan FSDP) melalui DCN.

Sebelum memulai

  • Login ke akun Google Cloud Anda. Jika Anda baru menggunakan Google Cloud, buat akun untuk mengevaluasi performa produk kami dalam skenario dunia nyata. Pelanggan baru juga mendapatkan kredit gratis senilai $300 untuk menjalankan, menguji, dan men-deploy workload.
  • Instal Google Cloud CLI.

  • Jika Anda menggunakan penyedia identitas (IdP) eksternal, Anda harus login ke gcloud CLI dengan identitas gabungan Anda terlebih dahulu.

  • Untuk melakukan inisialisasi gcloud CLI, jalankan perintah berikut:

    gcloud init
  • Buat atau pilih Google Cloud project.

    Peran yang diperlukan untuk memilih atau membuat project

    • Pilih project: Memilih project tidak memerlukan peran IAM tertentu—Anda dapat memilih project mana pun yang telah diberi peran.
    • Membuat project: Untuk membuat project, Anda memerlukan peran Pembuat Project (roles/resourcemanager.projectCreator), yang berisi izin resourcemanager.projects.create. Pelajari cara memberikan peran.
    • Buat Google Cloud project:

      gcloud projects create PROJECT_ID

      Ganti PROJECT_ID dengan nama untuk Google Cloud project yang Anda buat.

    • Pilih project Google Cloud yang Anda buat:

      gcloud config set project PROJECT_ID

      Ganti PROJECT_ID dengan nama project Google Cloud Anda.

  • Verifikasi bahwa penagihan diaktifkan untuk project Google Cloud Anda.

  • Aktifkan API yang diperlukan:

    Peran yang diperlukan untuk mengaktifkan API

    Untuk mengaktifkan API, Anda memerlukan peran IAM Service Usage Admin (roles/serviceusage.serviceUsageAdmin), yang berisi izin serviceusage.services.enable. Pelajari cara memberikan peran.

    gcloud services enable container.googleapis.com cloudbuild.googleapis.com
  • Instal Google Cloud CLI.

  • Jika Anda menggunakan penyedia identitas (IdP) eksternal, Anda harus login ke gcloud CLI dengan identitas gabungan Anda terlebih dahulu.

  • Untuk melakukan inisialisasi gcloud CLI, jalankan perintah berikut:

    gcloud init
  • Buat atau pilih Google Cloud project.

    Peran yang diperlukan untuk memilih atau membuat project

    • Pilih project: Memilih project tidak memerlukan peran IAM tertentu—Anda dapat memilih project mana pun yang telah diberi peran.
    • Membuat project: Untuk membuat project, Anda memerlukan peran Pembuat Project (roles/resourcemanager.projectCreator), yang berisi izin resourcemanager.projects.create. Pelajari cara memberikan peran.
    • Buat Google Cloud project:

      gcloud projects create PROJECT_ID

      Ganti PROJECT_ID dengan nama untuk Google Cloud project yang Anda buat.

    • Pilih project Google Cloud yang Anda buat:

      gcloud config set project PROJECT_ID

      Ganti PROJECT_ID dengan nama project Google Cloud Anda.

  • Verifikasi bahwa penagihan diaktifkan untuk project Google Cloud Anda.

  • Aktifkan API yang diperlukan:

    Peran yang diperlukan untuk mengaktifkan API

    Untuk mengaktifkan API, Anda memerlukan peran IAM Service Usage Admin (roles/serviceusage.serviceUsageAdmin), yang berisi izin serviceusage.services.enable. Pelajari cara memberikan peran.

    gcloud services enable container.googleapis.com cloudbuild.googleapis.com
  • Memberikan peran ke akun pengguna Anda. Jalankan perintah berikut satu kali untuk setiap peran IAM berikut: roles/container.admin, roles/iam.serviceAccountAdmin, roles/cloudbuild.builds.editor

    gcloud projects add-iam-policy-binding PROJECT_ID --member="user:USER_IDENTIFIER" --role=ROLE

    Ganti kode berikut:

    • PROJECT_ID: Project ID Anda.
    • USER_IDENTIFIER: ID untuk akun pengguna Anda. Misalnya, myemail@example.com.
    • ROLE: Peran IAM yang Anda berikan ke akun pengguna Anda.
  • Karena tutorial ini menggunakan TPU Trillium (v6e), pilih region atau zona dengan ketersediaan. Untuk mengetahui informasi selengkapnya, lihat Kuota Cloud TPU.

Menyiapkan lingkungan Anda

Dalam tutorial ini, Anda akan menggunakan Cloud Shell. Cloud Shell telah diinstal sebelumnya dengan alat command line gcloud, helm, dan kubectl yang digunakan dalam tutorial ini.

  1. Buka Google Cloud console.

  2. Di bagian atas jendela konsol Google Cloud , klik tombol Activate Cloud Shell Tombol Aktifkan Shell.

    Sesi Cloud Shell akan terbuka di dalam frame baru di konsolGoogle Cloud dan menampilkan perintah command line.

  3. Di terminal, clone repositori kubernetes-engine-samples:

    git clone https://github.com/GoogleCloudPlatform/kubernetes-engine-samples.git
    
  4. Ubah ke direktori yang berisi file contoh:

    cd kubernetes-engine-samples/ai-ml/gke-ray/raytrain/maxtext
    
  5. Membuat dan mengaktifkan lingkungan virtual Python:

    python3 -m venv ray-env
    source ray-env/bin/activate
    
  6. Instal Ray CLI:

    pip install "ray[default]==2.55.0"
    
  7. Tetapkan variabel lingkungan berikut:

    export PROJECT_ID=$(gcloud config get project)
    export PROJECT_NUMBER=$(gcloud projects describe ${PROJECT_ID} --format="value(projectNumber)")
    export GS_BUCKET=GS_BUCKET
    export KSA_NAME=KSA_NAME
    export NAMESPACE=default
    export CLUSTER_NAME=CLUSTER_NAME
    export REGION=REGION
    export ZONE=ZONE
    export CLUSTER_VERSION=1.35.2-gke.1842000
    

    Ganti kode berikut:

    • GS_BUCKET: nama bucket Cloud Storage.
    • KSA_NAME: nama Akun Layanan Kubernetes.
    • CLUSTER_NAME: nama cluster baru.
    • REGION: region tempat kapasitas TPU Trillium Anda tersedia.
    • ZONE: zona tempat kapasitas TPU Trillium Anda tersedia. Untuk mengetahui informasi selengkapnya, lihat Ketersediaan TPU di GKE.

Mengonfigurasi jaringan cluster untuk Cloud TPU Multislice

Dalam slice TPU multi-host, perangkat TPU berkomunikasi melalui interkoneksi antar-chip berkecepatan tinggi. Namun, saat menjalankan tugas Multislice, slice TPU harus berkomunikasi satu sama lain melalui DCN. Jaringan Pod Kubernetes standar dapat menyebabkan kemacetan traffic ini. Jenis mesin ct6e-standard-4t didukung oleh beberapa kartu antarmuka jaringan (NIC) fisik. Untuk mendapatkan performa terbaik, Anda membuat dua jaringan VPC tambahan dan menggunakan GKE DRANET untuk menghubungkannya langsung ke Pod Ray.

  1. Buat dua jaringan VPC tambahan dengan unit pelatihan maksimum (MTU) yang besar:

    gcloud compute networks create ${CLUSTER_NAME}-net-1 \
      --subnet-mode=custom \
      --mtu=8896
    
    gcloud compute networks create ${CLUSTER_NAME}-net-2 \
      --subnet-mode=custom \
      --mtu=8896
    
  2. Buat subnet khusus:

    gcloud compute networks subnets create tpu-subnet-1 \
      --network=${CLUSTER_NAME}-net-1 \
      --region=${REGION} \
      --range=10.50.0.0/16
    
    gcloud compute networks subnets create tpu-subnet-2 \
      --network=${CLUSTER_NAME}-net-2 \
      --region=${REGION} \
      --range=10.60.0.0/16
    

Membuat cluster GKE

Anda dapat mengonfigurasi KubeRay di TPU dalam cluster GKE Autopilot atau Standard. Sebaiknya gunakan cluster Autopilot untuk mendapatkan pengalaman Kubernetes yang terkelola sepenuhnya. Untuk memilih mode operasi GKE yang paling sesuai untuk workload Anda, lihat Tentang mode operasi GKE.

Untuk menggunakan DRANET yang dikelola GKE, cluster Anda harus menggunakan versi 1.35.2-gke.1842000 atau yang lebih baru untuk mode Autopilot, atau 1.34.1-gke.1829001 atau yang lebih baru untuk mode Standard. Tutorial ini menggunakan versi 1.35.2-gke.1842000.

Autopilot

  1. Jalankan perintah berikut di Cloud Shell:

    gcloud container clusters create-auto $CLUSTER_NAME \
        --enable-ray-operator \
        --machine-type=n1-standard-16 \
        --location=$REGION \
        --cluster-version=${CLUSTER_VERSION}
    
  2. Untuk berkomunikasi dengan cluster Anda, konfigurasi kubectl :

    gcloud container clusters get-credentials CLUSTER_NAME \
        --location=$REGION
    

Standar

  1. Di Cloud Shell, buat cluster Standar yang mengaktifkan add-on operator Ray dengan menjalankan perintah berikut:

    gcloud container clusters create $CLUSTER_NAME \
        --addons=RayOperator,GcsFuseCsiDriver \
        --machine-type=n1-standard-16 \
        --enable-dataplane-v2 \
        --workload-pool=$PROJECT_ID.svc.id.goog \
        --location=$ZONE \
        --cluster-version=${CLUSTER_VERSION}
    

    Perintah ini juga mengaktifkan GcsFuseCsiDriver, yang memungkinkan Pod memasang bucket Cloud Storage sebagai sistem file lokal. Pembuatan cluster mungkin memerlukan waktu beberapa menit.

  2. Untuk berkomunikasi dengan cluster Anda, konfigurasi kubectl:

    gcloud container clusters get-credentials CLUSTER_NAME \
        --location=$ZONE
    
  3. Buat node pool slice TPU multi-host pertama dengan GKE DRANET yang diaktifkan:

    gcloud container node-pools create v6e-16-0 \
        --location=$ZONE \
        --cluster=$CLUSTER_NAME \
        --machine-type=ct6e-standard-4t \
        --threads-per-core=1 \
        --tpu-topology=4x4 \
        --num-nodes=4 \
        --additional-node-network=network=${CLUSTER_NAME}-net-1,subnetwork=tpu-subnet-1 \
        --additional-node-network=network=${CLUSTER_NAME}-net-2,subnetwork=tpu-subnet-2 \
        --node-labels=cloud.google.com/gke-networking-dra-driver=true \
        --enable-gvnic \
        --scopes=https://www.googleapis.com/auth/cloud-platform
    
  4. Buat node pool slice TPU kedua:

    gcloud container node-pools create v6e-16-1 \
        --location=$ZONE \
        --cluster=$CLUSTER_NAME \
        --machine-type=ct6e-standard-4t \
        --threads-per-core=1 \
        --tpu-topology=4x4 \
        --num-nodes=4 \
        --additional-node-network=network=${CLUSTER_NAME}-net-1,subnetwork=tpu-subnet-1 \
        --additional-node-network=network=${CLUSTER_NAME}-net-2,subnetwork=tpu-subnet-2 \
        --node-labels=cloud.google.com/gke-networking-dra-driver=true \
        --enable-gvnic \
        --scopes=https://www.googleapis.com/auth/cloud-platform
    

GKE menyediakan node pool yang terdiri dari empat VM TPU Trillium (v6e), yang dikonfigurasi bersama sebagai slice TPU multi-host yang memiliki topologi 4x4. Kumpulan node ini siap untuk workload pelatihan terdistribusi.

Cluster GKE yang mendukung operator Ray akan otomatis menginstal KubeRay dan webhook TPU KubeRay di cluster Anda.

Mengonfigurasi bucket Cloud Storage dan akun layanan

  1. Buat bucket Cloud Storage untuk checkpoint bersama antara node TPU multi-host.

    gsutil mb -p ${PROJECT_ID} -c STANDARD -l ${REGION} gs://${GS_BUCKET}
    
  2. Untuk mengaktifkan akses ke bucket Cloud Storage, buat Akun Layanan Kubernetes:

    kubectl create serviceaccount ${KSA_NAME} --namespace ${NAMESPACE}
    
  3. Untuk mengaktifkan akses ke bucket Cloud Storage, tambahkan binding kebijakan IAM yang diperlukan ke akun layanan:

    gcloud storage buckets add-iam-policy-binding gs://${GS_BUCKET} \
        --member "principal://iam.googleapis.com/projects/${PROJECT_NUMBER}/locations/global/workloadIdentityPools/${PROJECT_ID}.svc.id.goog/subject/ns/${NAMESPACE}/sa/${KSA_NAME}" \
        --role "roles/storage.objectUser"
    

Membuat skrip pelatihan

Skrip maxtext_multi_slice_trainer.py menggunakan JaxTrainer Ray Train untuk menjalankan tugas pelatihan MaxText terdistribusi di dua slice TPU. Skrip mengonfigurasi lingkungan pelatihan untuk delapan worker TPU multi-host dan menjalankan tugas pelatihan MaxText di setiap node worker. Fungsi train_loop_per_worker membungkus titik entri utama MaxText, dan menggunakan penjadwal terdistribusi Ray untuk mengeksekusi pelatih MaxText pada slice TPU multi-host:

import os
from absl import app
import logging
from typing import Sequence
import ray
from ray.train.v2.api.config import ScalingConfig, RunConfig
from ray.train.v2.jax import JaxTrainer

def train_loop_per_worker(config):
    import maxtext
    from maxtext.trainers.pre_train.train import main as maxtext_main

    argv = config["argv"]
    maxtext_main(argv)

def main(argv: Sequence[str]):
    # Convert the config file path to an absolute path
    argv = list(argv)
    if len(argv) > 1:
        argv[1] = os.path.abspath(argv[1])

    trainer = JaxTrainer(
        train_loop_per_worker=train_loop_per_worker,
        train_loop_config={"argv": argv},
        scaling_config=ScalingConfig(
            use_tpu=True,
            num_workers=8,
            topology="4x4",
            accelerator_type="TPU-V6E",
            resources_per_worker={"TPU": 4},
            placement_strategy="SPREAD",
        ),
        run_config=RunConfig(
            name="maxtext_jaxtrainer",
            worker_runtime_env={
                "uv": {
                    # maxtext requires some additional deps
                    "packages": ["maxtext[tpu]==0.2.1"],
                    "uv_pip_install_options": ["--resolution=lowest"]
                },
            },
        ),
    )
    result = trainer.fit()
    logging.info("Training complete!")
    ray.shutdown()

if __name__ == "__main__":
    app.run(main)

Skrip sebelumnya menentukan instance JaxTrainer yang meminta delapan pekerja dan topologi 4x4. Secara internal, Ray menyediakan SlicePlacementGroup di dua slice TPU dan membantu memastikan bahwa pekerja Ray Train berjalan secara atomik di kedua slice, dengan satu pekerja per host.

Melatih model

  1. Manifes ray-cluster.tpu-multi-slice.yaml di direktori saat ini menentukan resource kustom RayCluster. Manifes ini mencakup DRANET ResourceClaimTemplate untuk menyediakan perangkat jaringan bagi GKE DRANET dan Multislice:

    apiVersion: resource.k8s.io/v1
    kind: ResourceClaimTemplate
    metadata:
      name: two-netdev
    spec:
      spec:
        devices:
          requests:
          - name: req-netdev
            exactly:
              deviceClassName: netdev.google.com
              allocationMode: ExactCount
              count: 2
    ---
    apiVersion: ray.io/v1
    kind: RayCluster
    metadata:
      name: maxtext-tpu-cluster
    spec:
      headGroupSpec:
        rayStartParams: {}
        template:
          metadata:
            annotations:
              gke-gcsfuse/volumes: "true"
              gke-gcsfuse/cpu-limit: "0"
              gke-gcsfuse/memory-limit: "0"
              gke-gcsfuse/ephemeral-storage-limit: "0"
          spec:
            serviceAccountName: ${KSA_NAME}
            containers:
              - name: ray-head
                image: rayproject/ray:nightly-py312-tpu
                imagePullPolicy: Always
                ports:
                - containerPort: 6379
                  name: gcs-server
                - containerPort: 8265
                  name: dashboard
                - containerPort: 10001
                  name: client
                resources:
                  limits:
                    memory: "16Gi"
                  requests:
                    cpu: "8"
                    memory: "16Gi"
                volumeMounts:
                - name: gcs-fuse-csi-ephemeral
                  mountPath: /data
                - name: dshm
                  mountPath: /dev/shm
            volumes:
            - name: dshm
              emptyDir:
                medium: Memory
            - name: gcs-fuse-csi-ephemeral
              csi:
                driver: gcsfuse.csi.storage.gke.io
                volumeAttributes:
                  bucketName: ${GS_BUCKET}
                  mountOptions: "implicit-dirs,uid=1000,gid=1000,dir-mode=775,file-mode=664,file-cache:max-size-mb:-1"
            nodeSelector:
              iam.gke.io/gke-metadata-server-enabled: "true"
      workerGroupSpecs:
        - replicas: 2
          numOfHosts: 4
          groupName: tpu-group
          rayStartParams: 
            metrics-export-port: "8082"
          template:
            metadata:
              annotations:
                gke-gcsfuse/volumes: "true"
                gke-gcsfuse/cpu-limit: "0"
                gke-gcsfuse/memory-limit: "0"
                gke-gcsfuse/ephemeral-storage-limit: "0"
            spec:
              serviceAccountName: ${KSA_NAME}
              resourceClaims:
              - name: netdev
                resourceClaimTemplateName: two-netdev
              containers:
                - name: ray-worker
                  image: rayproject/ray:nightly-py312-tpu
                  imagePullPolicy: Always
                  resources:
                    claims:
                    - name: netdev
                    limits:
                      memory: 200G
                      google.com/tpu: "4"
                    requests:
                      cpu: "8"
                      memory: 200G
                      google.com/tpu: "4"
                  env:
                    - name: MEGASCALE_NUM_SLICES
                      value: "2"
                    - name: MEGASCALE_PORT
                      value: "9915"
                    - name: JAX_PLATFORMS
                      value: tpu,cpu
                    - name: ENABLE_PJRT_COMPATIBILITY
                      value: "true"
                    - name: LIBTPU_INIT_ARGS
                      value: "--xla_tpu_scoped_vmem_limit_kib=122880 --xla_tpu_use_minor_sharding_for_major_trivial_input=true --xla_tpu_relayout_group_size_threshold_for_reduce_scatter=1 --xla_tpu_assign_all_reduce_scatter_layout --xla_tpu_enable_async_collective_fusion_fuse_all_gather=true --xla_tpu_enable_async_collective_fusion_multiple_steps=true --xla_tpu_overlap_compute_collective_tc=true --xla_enable_async_all_gather=true --megascale_grpc_interface_prefixes=eth1,eth2,lo"
                  securityContext:
                    privileged: true
                  volumeMounts:
                  - name: gcs-fuse-csi-ephemeral
                    mountPath: /data
                  - name: dshm
                    mountPath: /dev/shm
              volumes:
              - name: dshm
                emptyDir:
                  medium: Memory
              - name: gcs-fuse-csi-ephemeral
                csi:
                  driver: gcsfuse.csi.storage.gke.io
                  volumeAttributes:
                    bucketName: ${GS_BUCKET}
                    mountOptions: "implicit-dirs,uid=1000,gid=1000,dir-mode=775,file-mode=664,file-cache:max-size-mb:-1"
              nodeSelector:
                iam.gke.io/gke-metadata-server-enabled: "true"
                cloud.google.com/gke-tpu-accelerator: tpu-v6e-slice
                cloud.google.com/gke-tpu-topology: 4x4
    

    Spesifikasi RayCluster sebelumnya membuat grup pekerja TPU dengan delapan pekerja (numOfHosts: 4) per replika, dengan dua replika. Setiap pekerja meminta empat chip TPU (google.com/tpu: "4"). Setiap pekerja dijadwalkan di node TPU Trillium (tpu-v6e-slice), yang merupakan bagian dari slice multi-host yang sama dan ditempatkan bersama. KubeRay menskalakan keempat pekerja dalam slice secara atomik. Variabel lingkungan JAX yang diperlukan, serta Afinitas Pod untuk penjadwalan, di-bootstrap oleh GKE melalui webhook mutating.

  2. Untuk membuat RayCluster, terapkan manifes:

    envsubst < ray-cluster.tpu-multi-slice.yaml | kubectl apply -f -
    
  3. Pastikan cluster sudah siap dan berjalan:

    kubectl get rayclusters maxtext-tpu-cluster
    

    Outputnya akan mirip dengan berikut ini:

    NAME                  DESIRED WORKERS   AVAILABLE WORKERS   CPUS   MEMORY         GPUS   STATUS   AGE
    maxtext-tpu-cluster   8                 8                   72     1579277216Ki   0      ready    2m11s
    
  4. Untuk mengakses Dasbor Ray melalui layanan head Ray, buat sesi penerusan port:

    kubectl port-forward svc/maxtext-tpu-cluster-head-svc 8265:8265 2>&1 >/dev/null &
    
  5. Verifikasi bahwa RayCluster dapat dijangkau dari lingkungan lokal Anda:

    ray list nodes --address http://localhost:8265
    

    Outputnya akan mirip dengan berikut ini:

    ray list nodes --address http://localhost:8265
    2026-04-21 10:20:20,080 - INFO - Note: NumExpr detected 64 cores but "NUMEXPR_MAX_THREADS" not set, so enforcing safe limit of 8.
    2026-04-21 10:20:20,080 - INFO - NumExpr defaulting to 8 threads.
    
    ======== List: 2026-04-21 10:20:20.945431 ========
    Stats:
    ------------------------------
    Total: 9
    
    Table:
    ------------------------------
        NODE_ID                                                   NODE_IP     IS_HEAD_NODE    STATE    STATE_MESSAGE    NODE_NAME    RESOURCES_TOTAL                   LABELS
    0  4f0e4d742de5375047c7688f4d2bc64a42d1e5c77c2d8344b3b375a1  10.68.9.5   False           ALIVE                     10.68.9.5    CPU: 8.0                          ray.io/accelerator-type: TPU-V6E
                                                                                                                                    TPU: 4.0                          ray.io/node-group: tpu-group
                                                                                                                                    accelerator_type:TPU-V6E: 1.0     ray.io/node-id: 4f0e4d742...
                                                                                                                                    memory: 186.265 GiB               ray.io/tpu-pod-type: v6e-16
                                                                                                                                    node:10.68.9.5: 1.0               ray.io/tpu-slice-name: tpu-group-0
                                                                                                                                    object_store_memory: 186.265 GiB  ray.io/tpu-topology: 4x4
                                                                                                                                    tpu-group-0: 1.0                  ray.io/tpu-worker-id: '1'
    ...
    6  ce7056807b95831ce107ba1951dac34b80635e6fdbb312e7f9649938  10.68.2.9   True            ALIVE                     10.68.2.9    CPU: 8.0                          ray.io/node-group: headgroup
                                                                                                                                    memory: 16.000 GiB                ray.io/node-id: ce7056807...
                                                                                                                                    node:10.68.2.9: 1.0
                                                                                                                                    node:__internal_head__: 1.0
                                                                                                                                    object_store_memory: 4.765 GiB
    ...
    
  6. Download file konfigurasi MaxText dasar. File ini diperlukan oleh skrip pelatihan untuk menetapkan hyperparameter default model:

    curl -O https://raw.githubusercontent.com/google/maxtext/maxtext-v0.2.1/src/maxtext/configs/base.yml
    
  7. Kirimkan skrip JaxTrainer ke RayCluster dan periksa apakah RayJob berhasil diselesaikan:

Llama 3 8B

ray job submit \
  --address http://localhost:8265 \
  --working-dir . \
  --runtime-env-json '{"excludes": ["ray-env", ".git"]}' \
  -- python maxtext_multi_slice_trainer.py \
      base.yml \
      base_output_directory=/data/ \
      dataset_type=synthetic \
      per_device_batch_size=4 \
      max_target_length=4096 \
      model_name=llama3-8b \
      steps=100 \
      ici_fsdp_parallelism=4 \
      ici_tensor_parallelism=4 \
      run_name=rayjob-multi-slice

Llama 3 70B

ray job submit \
  --address http://localhost:8265 \
  --working-dir . \
  --runtime-env-json '{"excludes": ["ray-env", ".git"]}' \
  -- python maxtext_multi_slice_trainer.py \
      base.yml \
      base_output_directory=/data/ \
      dataset_type=synthetic \
      per_device_batch_size=2 \
      max_target_length=4096 \
      model_name=llama3-70b \
      steps=100 \
      ici_tensor_parallelism=4 \
      ici_fsdp_parallelism=4 \
      dcn_fsdp_parallelism=2 \
      dcn_data_parallelism=1 \
      remat_policy=full \
      run_name=rayjob-multi-slice-70b-fsdp

Perintah sebelumnya mengirimkan skrip Python, yang memanggil kode Ray JaxTrainer ke RayCluster. Perintah ray job submit menyertakan beberapa argumen khusus MaxText untuk diteruskan ke konfigurasi model.

Di terminal, Anda akan melihat output yang mirip dengan berikut untuk tugas Llama 3 70B:

[process=5][thread=save_finalize][step=99] CheckpointManager Save Finalize is done on all hosts. [repeated 7x across cluster]
(RayTrainWorker pid=130520, ip=10.60.7.7) [process=5][thread=TrainingThread(train_fn_with_final_checkpoint_flush)][step=99][wait_until_finished] Done waiting for Save Finalize thread (save_finalize) running at step=99. [repeated 7x across cluster]
(RayTrainWorker pid=130520, ip=10.60.7.7) [process=5][thread=TrainingThread(train_fn_with_final_checkpoint_flush)][wait_until_finished] No Save Finalize thread to wait for. Returning. [repeated 6x across cluster]
(RayTrainWorker pid=130520, ip=10.60.7.7) completed step: 99, seconds: 0.693, TFLOP/s/device: 83.171, Tokens/s/device: 11819.175, total_weights: 262144, loss: 0.334 [repeated 6x across cluster]

------------------------------------------
Job 'raysubmit_XwUdZMrhsYRKvjqs' succeeded
------------------------------------------

Menjalankan pelatihan elastis multislice di Spot VM

Saat menggunakan akselerator yang sangat dicari seperti TPU, penggunaan Spot VM dapat mengurangi biaya secara signifikan. Namun, Spot VM dapat di-preempt secara tiba-tiba.

Ray Train mendukung pelatihan elastis, yang memungkinkan tugas Anda menskalakan jumlah slice TPU yang berpartisipasi secara dinamis ke atas atau ke bawah tanpa gagal. Jika slice di-preempt, Ray akan menjeda loop pelatihan, menunggu worker yang tersisa untuk mengatur ulang, memulihkan dari checkpoint MaxText terbaru, dan melanjutkan pelatihan dengan footprint yang lebih kecil.

Untuk mengaktifkan pelatihan elastis, ubah parameter num_workers di ScalingConfig dari bilangan bulat statis menjadi tuple yang merepresentasikan (minimum_workers, maximum_workers). Selain itu, tambahkan FailureConfig(max_failures=3) ke RunConfig, yang menginstruksikan Ray Train untuk mencoba ulang loop pelatihan hingga 3 kali, bukan menggagalkan seluruh tugas saat pekerja dihentikan sementara.

Memperbarui skrip Ray Train

  1. Skrip maxtext_elastic_trainer.py di direktori saat ini memungkinkan pelatihan elastis. Perhatikan bahwa perintah ini menetapkan num_workers=(4,8), yang memberi tahu Ray untuk melanjutkan jika setidaknya satu slice 16-chip (empat pekerja) tersedia, tetapi untuk menskalakan hingga dua slice (delapan pekerja) jika memungkinkan. Objek ini mencakup FailureConfig untuk mengaktifkan pelatihan elastis, menentukan jumlah percobaan ulang, dan membantu memastikan tugas tetap berjalan setelah dihentikan:

    import os
    from absl import app
    import logging
    from typing import Sequence
    import ray
    from ray.train.v2.api.config import ScalingConfig, RunConfig, FailureConfig
    from ray.train.v2.jax import JaxTrainer
    
    def train_loop_per_worker(config):
        import maxtext
        from maxtext.trainers.pre_train.train import main as maxtext_main
    
        argv = config["argv"]
        maxtext_main(argv)
    
    def main(argv: Sequence[str]):
        # Convert the config file path to an absolute path
        argv = list(argv)
        if len(argv) > 1:
            argv[1] = os.path.abspath(argv[1])
    
        trainer = JaxTrainer(
            train_loop_per_worker=train_loop_per_worker,
            train_loop_config={"argv": argv},
            scaling_config=ScalingConfig(
                use_tpu=True,
                # This tells Ray to scale the number of workers between 4 and 8 (i.e. 1 to 2 TPU slices).
                num_workers=(4,8),
                topology="4x4",
                accelerator_type="TPU-V6E",
                resources_per_worker={"TPU": 4},
                placement_strategy="SPREAD",
            ),
            run_config=RunConfig(
                name="maxtext_jaxtrainer",
                # Define a FailureConfig to enable fault tolerance by automatically restarting failed workers.
                failure_config=FailureConfig(max_failures=3),
                worker_runtime_env={
                    "uv": {
                        # maxtext requires some additional deps
                        "packages": ["maxtext[tpu]==0.2.1"],
                        "uv_pip_install_options": ["--resolution=lowest"]
                    },
                },
            ),
        )
        result = trainer.fit()
        logging.info("Training complete!")
        ray.shutdown()
    
    if __name__ == "__main__":
        app.run(main)
    
  2. Kirimkan tugas menggunakan Ray Job CLI. Pastikan untuk memberikan run_name yang unik agar titik pemeriksaan tidak bertentangan dengan proses sebelumnya.

    ray job submit \
      --address http://localhost:8265 \
      --working-dir . \
      --runtime-env-json '{"excludes": ["ray-env", ".git"]}' \
      -- python maxtext_elastic_trainer.py \
          base.yml \
          base_output_directory=/data/ \
          dataset_type=synthetic \
          per_device_batch_size=4 \
          max_target_length=4096 \
          model_name=llama3-8b \
          steps=100 \
          ici_fsdp_parallelism=4 \
          ici_tensor_parallelism=4 \
          run_name=rayjob-elastic-8b
    
  3. Untuk menyimulasikan penghentian atau preemption node selama pelatihan, hapus Pod.

    kubectl delete pod $(kubectl get pods -l ray.io/node-type=worker -o jsonpath='{.items[0].metadata.name}')
    

Terminal mencatat kegagalan pekerja, tetapi pengontrol orkestrasi mempertahankan tugas tetap aktif dan otomatis dilanjutkan dari titik pemeriksaan /data/rayjob-elastic-8b/checkpoints setelah topologi minimum tersedia.

Karena MaxText menghitung ulang mesh perangkat secara dinamis saat dilanjutkan, Anda tidak perlu menulis logika kustom untuk menangani perubahan partisi ulang titik pemeriksaan saat topologi menyusut. Checkpointer Orbax JAX akan otomatis membagi ulang bobot yang disimpan ke tata letak fisik baru sebelum melanjutkan loop pelatihan. Output berikut menunjukkan pengontrol Ray Train mendeteksi resource TPU yang baru tersedia di cluster dan melakukan operasi penskalaan dari satu slice (empat pekerja) ke dua slice (delapan pekerja) selama pelatihan.

...
(pid=, ip=10.68.9.5) W0421 04:19:07.570048   20579 grpc_transport.cc:1930] GetMultiSliceTopology returned with status: UNAVAILABLE: failed to connect to all addresses; last error: UNKNOWN: ipv4:10.68.8.5:9915: connect endpoint failed (Failed to connect to remote host: Connection refused)
...
(TrainController pid=23150) Detected changes in the cluster resources. Deciding to resize the worker group from 4 -> 8 workers.
(TrainController pid=23150) Using SlicePlacementGroup utility to reserve 2 slice(s) with topology '4x4'...
(TrainController pid=23150) Attempting to start training worker group of size 8 with the following resources: [{'TPU': 4, 'accelerator_type:TPU-V6E': 0.001}] * 8

Pembersihan

Agar akun Google Cloud Anda tidak dikenai biaya untuk resource yang digunakan dalam tutorial ini, hapus project yang berisi resource tersebut, atau simpan project dan hapus setiap resource.

  1. Hapus RayCluster:

    kubectl delete raycluster maxtext-tpu-cluster
    
  2. Hapus cluster GKE:

    gcloud container clusters delete $CLUSTER_NAME --zone=$ZONE
    
  3. Hapus bucket Cloud Storage:

    gsutil rm -r gs://${GS_BUCKET}
    

Langkah berikutnya